Kyopro Library
 
読み取り中…
検索中…
一致する文字列を見つけられません
fps2.hpp
[詳解]
// NTT: number-theoretic transform over the modulus p = mint::get_mod().
// Intended for "NTT-friendly" primes (p - 1 divisible by a large power of
// two); transform lengths must be powers of two no larger than 2^level.
// NOTE(review): source line 56 -- the NTT() constructor listed in the symbol
// index -- is missing from this listing. ntt()/intt() read dw/dy without
// calling setwy() themselves, so the constructor presumably runs
// setwy(level); confirm.
1template <typename mint>
2struct NTT {
// Smallest primitive root modulo p (p assumed prime), evaluated at compile
// time via the constexpr member `pr` below: collect the distinct prime
// factors of p - 1 by trial division, then return the first g >= 2 with
// g^((p - 1) / q) != 1 (mod p) for every such prime factor q.
3 static constexpr uint32_t get_pr() {
4 uint32_t _mod = mint::get_mod();
5 using u64 = uint64_t;
6 u64 ds[32] = {};
7 int idx = 0;
8 u64 m = _mod - 1;
9 for (u64 i = 2; i * i <= m; ++i) {
10 if (m % i == 0) {
11 ds[idx++] = i;
12 while (m % i == 0) m /= i;
13 }
14 }
// Whatever is left after trial division is itself a prime factor.
15 if (m != 1) ds[idx++] = m;
16
17 uint32_t _pr = 2;
18 while (1) {
19 int flg = 1;
20 for (int i = 0; i < idx; ++i) {
// r = _pr^((p - 1) / ds[i]) mod p by square-and-multiply.
21 u64 a = _pr, b = (_mod - 1) / ds[i], r = 1;
22 while (b) {
23 if (b & 1) r = r * a % _mod;
24 a = a * a % _mod;
25 b >>= 1;
26 }
27 if (r == 1) {
28 flg = 0;
29 break;
30 }
31 }
32 if (flg == 1) break;
33 ++_pr;
34 }
35 return _pr;
36 };
37
// mod: the modulus p; pr: its primitive root; level: the exponent of 2 in
// p - 1, so the largest supported transform length is 2^level.
38 static constexpr uint32_t mod = mint::get_mod();
39 static constexpr uint32_t pr = get_pr();
40 static constexpr int level = __builtin_ctzll(mod - 1);
// Per-step twiddle multipliers used by fft4 (dw) and ifft4 (dy); written only
// by setwy().
41 mint dw[level], dy[level];
42
// Fill dw/dy for transforms of length up to 2^k.
// w[k - 1] = pr^((p - 1) / 2^k) is a primitive 2^k-th root of unity and
// w[i] = w[i + 1]^2, so w[i] = pr^((p - 1) / 2^(i + 1)) regardless of k;
// y[i] is the inverse of w[i]. Requires k >= 3: dw[2]/dy[2] read w[2]/y[2],
// which are only assigned when k - 1 >= 2.
43 void setwy(int k) {
44 mint w[level], y[level];
45 w[k - 1] = mint(pr).pow((mod - 1) / (1 << k));
46 y[k - 1] = w[k - 1].inverse();
47 for (int i = k - 2; i > 0; --i)
48 w[i] = w[i + 1] * w[i + 1], y[i] = y[i + 1] * y[i + 1];
49 dw[1] = w[1], dy[1] = y[1], dw[2] = w[2], dy[2] = y[2];
50 for (int i = 3; i < k; ++i) {
51 dw[i] = dw[i - 1] * y[i - 2] * w[i];
52 dy[i] = dy[i - 1] * w[i - 2] * y[i];
53 }
54 }
55
57
// In-place forward transform of a, where a.size() == 2^k.
// Performs one radix-2 pass when k is odd, then radix-4 passes with a
// shrinking stride v. No bit-reversal permutation is applied; ifft4 consumes
// exactly this output order, which is all pointwise multiplication needs.
// NOTE(review): the output order looks bit-reversed -- confirm before
// relying on individual output indices.
58 void fft4(vector<mint> &a, int k) {
59 if ((int)a.size() <= 1) return;
60 if (k == 1) {
61 mint a1 = a[1];
62 a[1] = a[0] - a[1];
63 a[0] = a[0] + a1;
64 return;
65 }
66 if (k & 1) {
67 int v = 1 << (k - 1);
68 for (int j = 0; j < v; ++j) {
69 mint ajv = a[j + v];
70 a[j + v] = a[j] - ajv;
71 a[j] += ajv;
72 }
73 }
74 int u = 1 << (2 + (k & 1));
75 int v = 1 << (k - 2 - (k & 1));
76 mint one = mint(1);
// imag = dw[1] = w[1], a primitive 4th root of unity.
77 mint imag = dw[1];
78 while (v) {
79 // jh = 0
80 {
81 int j0 = 0;
82 int j1 = v;
83 int j2 = j1 + v;
84 int j3 = j2 + v;
85 for (; j0 < v; ++j0, ++j1, ++j2, ++j3) {
86 mint t0 = a[j0], t1 = a[j1], t2 = a[j2], t3 = a[j3];
87 mint t0p2 = t0 + t2, t1p3 = t1 + t3;
88 mint t0m2 = t0 - t2, t1m3 = (t1 - t3) * imag;
89 a[j0] = t0p2 + t1p3, a[j1] = t0p2 - t1p3;
90 a[j2] = t0m2 + t1m3, a[j3] = t0m2 - t1m3;
91 }
92 }
93 // jh >= 1
94 mint ww = one, xx = one * dw[2], wx = one;
95 for (int jh = 4; jh < u;) {
96 ww = xx * xx, wx = ww * xx;
97 int j0 = jh * v;
98 int je = j0 + v;
99 int j2 = je + v;
100 for (; j0 < je; ++j0, ++j2) {
101 mint t0 = a[j0], t1 = a[j0 + v] * xx, t2 = a[j2] * ww,
102 t3 = a[j2 + v] * wx;
103 mint t0p2 = t0 + t2, t1p3 = t1 + t3;
104 mint t0m2 = t0 - t2, t1m3 = (t1 - t3) * imag;
105 a[j0] = t0p2 + t1p3, a[j0 + v] = t0p2 - t1p3;
106 a[j2] = t0m2 + t1m3, a[j2 + v] = t0m2 - t1m3;
107 }
// Advance the twiddle for the next group; the step factor is selected by
// the trailing-zero count of the new jh.
108 xx *= dw[__builtin_ctzll((jh += 4))];
109 }
110 u <<= 2;
111 v >>= 2;
112 }
113 }
114
// In-place inverse of fft4 *without* the 1/2^k scaling (intt and multiply
// apply it). Mirrors fft4: radix-4 passes with a growing stride v (starting
// at 1), then a final radix-2 pass when k is odd.
115 void ifft4(vector<mint> &a, int k) {
116 if ((int)a.size() <= 1) return;
117 if (k == 1) {
118 mint a1 = a[1];
119 a[1] = a[0] - a[1];
120 a[0] = a[0] + a1;
121 return;
122 }
123 int u = 1 << (k - 2);
124 int v = 1;
125 mint one = mint(1);
// imag = dy[1], the inverse of the 4th root used by fft4.
126 mint imag = dy[1];
127 while (u) {
128 // jh = 0
129 {
130 int j0 = 0;
131 int j1 = v;
132 int j2 = v + v;
133 int j3 = j2 + v;
134 for (; j0 < v; ++j0, ++j1, ++j2, ++j3) {
135 mint t0 = a[j0], t1 = a[j1], t2 = a[j2], t3 = a[j3];
136 mint t0p1 = t0 + t1, t2p3 = t2 + t3;
137 mint t0m1 = t0 - t1, t2m3 = (t2 - t3) * imag;
138 a[j0] = t0p1 + t2p3, a[j2] = t0p1 - t2p3;
139 a[j1] = t0m1 + t2m3, a[j3] = t0m1 - t2m3;
140 }
141 }
142 // jh >= 1
143 mint ww = one, xx = one * dy[2], yy = one;
144 u <<= 2;
145 for (int jh = 4; jh < u;) {
146 ww = xx * xx, yy = xx * imag;
147 int j0 = jh * v;
148 int je = j0 + v;
149 int j2 = je + v;
150 for (; j0 < je; ++j0, ++j2) {
151 mint t0 = a[j0], t1 = a[j0 + v], t2 = a[j2], t3 = a[j2 + v];
152 mint t0p1 = t0 + t1, t2p3 = t2 + t3;
153 mint t0m1 = (t0 - t1) * xx, t2m3 = (t2 - t3) * yy;
154 a[j0] = t0p1 + t2p3, a[j2] = (t0p1 - t2p3) * ww;
155 a[j0 + v] = t0m1 + t2m3, a[j2 + v] = (t0m1 - t2m3) * ww;
156 }
157 xx *= dy[__builtin_ctzll(jh += 4)];
158 }
159 u >>= 4;
160 v <<= 2;
161 }
162 if (k & 1) {
163 u = 1 << (k - 1);
164 for (int j = 0; j < u; ++j) {
165 mint ajv = a[j] - a[j + u];
166 a[j] += a[j + u];
167 a[j + u] = ajv;
168 }
169 }
170 }
171
// Forward transform of a. Only the trailing-zero count of a.size() is used,
// so a.size() is assumed to be a power of two.
172 void ntt(vector<mint> &a) {
173 if ((int)a.size() <= 1) return;
174 fft4(a, __builtin_ctz(a.size()));
175 }
176
// Inverse transform of a, including the division by a.size().
177 void intt(vector<mint> &a) {
178 if ((int)a.size() <= 1) return;
179 ifft4(a, __builtin_ctz(a.size()));
180 mint iv = mint(a.size()).inverse();
181 for (auto &x : a) x *= iv;
182 }
183
// Product of polynomials a and b; the result has a.size() + b.size() - 1
// coefficients. Uses schoolbook multiplication when the shorter input has at
// most 40 terms; otherwise a cyclic NTT of length M, the smallest power of two
// >= l (and at least 4). When a == b, the single transform is squared.
// Precondition: a and b are non-empty. With both empty, l == -1 and
// vector<mint> s(l) converts -1 to an enormous size.
// FormalPowerSeries::operator*= rejects empty operands before calling this.
184 vector<mint> multiply(const vector<mint> &a, const vector<mint> &b) {
185 int l = a.size() + b.size() - 1;
186 if (min<int>(a.size(), b.size()) <= 40) {
187 vector<mint> s(l);
188 for (int i = 0; i < (int)a.size(); ++i)
189 for (int j = 0; j < (int)b.size(); ++j) s[i + j] += a[i] * b[j];
190 return s;
191 }
192 int k = 2, M = 4;
193 while (M < l) M <<= 1, ++k;
// Rewrites dw/dy with the same values on every call (see setwy: w[i] does not
// depend on k); redundant but harmless.
194 setwy(k);
195 vector<mint> s(M);
196 for (int i = 0; i < (int)a.size(); ++i) s[i] = a[i];
197 fft4(s, k);
198 if (a.size() == b.size() && a == b) {
199 for (int i = 0; i < M; ++i) s[i] *= s[i];
200 } else {
201 vector<mint> t(M);
202 for (int i = 0; i < (int)b.size(); ++i) t[i] = b[i];
203 fft4(t, k);
204 for (int i = 0; i < M; ++i) s[i] *= t[i];
205 }
206 ifft4(s, k);
207 s.resize(l);
208 mint invm = mint(M).inverse();
209 for (int i = 0; i < l; ++i) s[i] *= invm;
210 return s;
211 }
212
// Given a = the transform of a length-M sequence (M = a.size()), append M
// more values. The steps are:
//   - recover the coefficients with intt;
//   - multiply coefficient i by zeta^i, where zeta = pr^((p - 1) / 2M) is a
//     primitive 2M-th root of unity;
//   - transform again and append the result to a.
// This is intended to turn a into the length-2M transform, in this struct's
// output order, without a full size-2M NTT. Verify the ordering assumption
// against callers.
213 void ntt_doubling(vector<mint> &a) {
214 int M = (int)a.size();
215 auto b = a;
216 intt(b);
217 mint r = 1, zeta = mint(pr).pow((mint::get_mod() - 1) / (M << 1));
218 for (int i = 0; i < M; i++) b[i] *= r, r *= zeta;
219 ntt(b);
220 copy(begin(b), end(b), back_inserter(a));
221 }
222};
223
// FormalPowerSeries<mint>: a polynomial / truncated power series stored as a
// vector of coefficients in increasing degree (index i is the coefficient of
// x^i, as eval() and diff() treat it).
// NOTE(review): source line 225 -- presumably the struct header
// `struct FormalPowerSeries : vector<mint> {` -- is missing from this listing.
224template <typename mint>
// Re-export every std::vector constructor.
226 using vector<mint>::vector;
227 using FPS = FormalPowerSeries;
228
229 FPS &operator+=(const FPS &r) {
230 if (r.size() > this->size()) this->resize(r.size());
231 for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
232 return *this;
233 }
234
235 FPS &operator+=(const mint &r) {
236 if (this->empty()) this->resize(1);
237 (*this)[0] += r;
238 return *this;
239 }
240
241 FPS &operator-=(const FPS &r) {
242 if (r.size() > this->size()) this->resize(r.size());
243 for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
244 return *this;
245 }
246
247 FPS &operator-=(const mint &r) {
248 if (this->empty()) this->resize(1);
249 (*this)[0] -= r;
250 return *this;
251 }
252
253 FPS &operator*=(const mint &v) {
254 for (int k = 0; k < (int)this->size(); k++) (*this)[k] *= v;
255 return *this;
256 }
257
258 FPS &operator/=(const FPS &r) {
259 if (this->size() < r.size()) {
260 this->clear();
261 return *this;
262 }
263 int n = this->size() - r.size() + 1;
264 if ((int)r.size() <= 64) {
265 FPS f(*this), g(r);
266 g.shrink();
267 mint coeff = g.back().inverse();
268 for (auto &x : g) x *= coeff;
269 int deg = (int)f.size() - (int)g.size() + 1;
270 int gs = g.size();
271 FPS quo(deg);
272 for (int i = deg - 1; i >= 0; i--) {
273 quo[i] = f[i + gs - 1];
274 for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
275 }
276 *this = quo * coeff;
277 this->resize(n, mint(0));
278 return *this;
279 }
280 return *this = ((*this).rev().pre(n) * r.rev().inv(n)).pre(n).rev();
281 }
282
283 FPS &operator%=(const FPS &r) {
284 *this -= *this / r * r;
285 shrink();
286 return *this;
287 }
288
289 FPS operator+(const FPS &r) const { return FPS(*this) += r; }
290 FPS operator+(const mint &v) const { return FPS(*this) += v; }
291 FPS operator-(const FPS &r) const { return FPS(*this) -= r; }
292 FPS operator-(const mint &v) const { return FPS(*this) -= v; }
293 FPS operator*(const FPS &r) const { return FPS(*this) *= r; }
294 FPS operator*(const mint &v) const { return FPS(*this) *= v; }
295 FPS operator/(const FPS &r) const { return FPS(*this) /= r; }
296 FPS operator%(const FPS &r) const { return FPS(*this) %= r; }
297 FPS operator-() const {
298 FPS ret(this->size());
299 for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
300 return ret;
301 }
302
303 void shrink() {
304 while (this->size() && this->back() == mint(0)) this->pop_back();
305 }
306
307 FPS rev() const {
308 FPS ret(*this);
309 reverse(begin(ret), end(ret));
310 return ret;
311 }
312
313 FPS dot(FPS r) const {
314 FPS ret(min(this->size(), r.size()));
315 for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
316 return ret;
317 }
318
319 // 前 sz 項を取ってくる。sz に足りない項は 0 埋めする
320 FPS pre(int sz) const {
321 FPS ret(begin(*this), begin(*this) + min((int)this->size(), sz));
322 if ((int)ret.size() < sz) ret.resize(sz);
323 return ret;
324 }
325
326 FPS operator>>(int sz) const {
327 if ((int)this->size() <= sz) return {};
328 FPS ret(*this);
329 ret.erase(ret.begin(), ret.begin() + sz);
330 return ret;
331 }
332
333 FPS operator<<(int sz) const {
334 FPS ret(*this);
335 ret.insert(ret.begin(), sz, mint(0));
336 return ret;
337 }
338
339 FPS diff() const {
340 const int n = (int)this->size();
341 FPS ret(max(0, n - 1));
342 mint one(1), coeff(1);
343 for (int i = 1; i < n; i++) {
344 ret[i - 1] = (*this)[i] * coeff;
345 coeff += one;
346 }
347 return ret;
348 }
349
350 FPS integral() const {
351 const int n = (int)this->size();
352 FPS ret(n + 1);
353 ret[0] = mint(0);
354 if (n > 0) ret[1] = mint(1);
355 auto mod = mint::get_mod();
356 for (int i = 2; i <= n; i++) ret[i] = (-ret[mod % i]) * (mod / i);
357 for (int i = 0; i < n; i++) ret[i + 1] *= (*this)[i];
358 return ret;
359 }
360
361 mint eval(mint x) const {
362 mint r = 0, w = 1;
363 for (auto &v : *this) r += w * v, w *= x;
364 return r;
365 }
366
367 FPS log(int deg = -1) const {
368 assert(!(*this).empty() && (*this)[0] == mint(1));
369 if (deg == -1) deg = (int)this->size();
370 return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
371 }
372
373 FPS pow(int64_t k, int deg = -1) const {
374 const int n = (int)this->size();
375 if (deg == -1) deg = n;
376 if (k == 0) {
377 FPS ret(deg);
378 if (deg) ret[0] = 1;
379 return ret;
380 }
381 for (int i = 0; i < n; i++) {
382 if ((*this)[i] != mint(0)) {
383 mint rev = mint(1) / (*this)[i];
384 FPS ret = (((*this * rev) >> i).log(deg) * k).exp(deg);
385 ret *= (*this)[i].pow(k);
386 ret = (ret << (i * k)).pre(deg);
387 if ((int)ret.size() < deg) ret.resize(deg, mint(0));
388 return ret;
389 }
390 if (__int128_t(i + 1) * k >= deg) return FPS(deg, mint(0));
391 }
392 return FPS(deg, mint(0));
393 }
394
// Type-erased pointer to the shared NTT<mint> engine (one per mint
// instantiation); created lazily by set_fft().
395 static void *ntt_ptr;
// Allocates the engine on first use.
396 static void set_fft();
// Polynomial product via the engine (empty if either operand is empty).
397 FPS &operator*=(const FPS &r);
// In-place forward / inverse transforms of this coefficient vector.
398 void ntt();
399 void intt();
// Extends an in-transform-domain vector of length M to length 2M.
400 void ntt_doubling();
// The engine's primitive root.
401 static int ntt_pr();
// Multiplicative inverse and exponential, first deg terms (deg = -1: size()).
402 FPS inv(int deg = -1) const;
403 FPS exp(int deg = -1) const;
404};
// Definition of the static engine pointer: null until set_fft() allocates the
// NTT<mint>.
405template <typename mint>
406void *FormalPowerSeries<mint>::ntt_ptr = nullptr;
407
408/**
409 * @brief 多項式/形式的冪級数ライブラリ
410 * @docs docs/fps/formal-power-series.md
411 */
412#line 8 "fps/sparse-fps.hpp"
413
414// Assumes g is sparse; computes f * g.inv(), i.e. the first deg terms of f / g.
// Requires g[0] != 0. Cost is O(deg * number of non-zero g[i], i >= 1).
// NOTE(review): source line 416 -- the signature, listed in the symbol index as
// sparse_div(const FormalPowerSeries<mint>& f, const FormalPowerSeries<mint>& g,
// int deg = -1) -- is missing from this listing.
415template <typename mint>
417 const FormalPowerSeries<mint>& g,
418 int deg = -1) {
419 assert(g.empty() == false && g[0] != mint(0));
420 if (deg == -1) deg = f.size();
421 mint ig0 = g[0].inverse();
// Normalise so that g has constant term 1: s starts as f / g[0].
422 FormalPowerSeries<mint> s = f * ig0;
423 s.resize(deg);
// Non-zero terms g[i] / g[0] for i >= 1, in increasing i (needed for the
// early break below).
424 vector<pair<int, mint>> gs;
425 for (int i = 1; i < (int)g.size(); i++) {
426 if (g[i] != 0) gs.emplace_back(i, g[i] * ig0);
427 }
// Forward substitution: once s[i] is final, remove its contribution from the
// higher coefficients.
428 for (int i = 0; i < deg; i++) {
429 for (auto& [j, g_j] : gs) {
430 if (i + j >= deg) break;
431 s[i + j] -= s[i] * g_j;
432 }
433 }
434 return s;
435}
436
// sparse_inv(f, deg): first deg terms of 1 / f for a sparse f with f[0] != 0,
// via g[k] = -(1 / f[0]) * sum_{j >= 1} f[j] * g[k - j].
// Cost is O(deg * number of non-zero f[i], i >= 1).
// NOTE(review): source line 438 (the signature) is missing from this listing.
437template <typename mint>
439 int deg = -1) {
440 assert(f.empty() == false && f[0] != mint(0));
441 if (deg == -1) deg = f.size();
// Non-zero terms of f beyond the constant, in increasing index.
442 vector<pair<int, mint>> fs;
443 for (int i = 1; i < (int)f.size(); i++) {
444 if (f[i] != 0) fs.emplace_back(i, f[i]);
445 }
446 FormalPowerSeries<mint> g(deg);
447 mint if0 = f[0].inverse();
448 if (0 < deg) g[0] = if0;
449 for (int k = 1; k < deg; k++) {
450 for (auto& [j, fj] : fs) {
451 if (k < j) break;
452 g[k] += g[k - j] * fj;
453 }
454 g[k] *= -if0;
455 }
456 return g;
457}
458
// sparse_log(f, deg): first deg terms of log f for a sparse f with f[0] == 1.
// The recurrence comes from f * g' = f':
//   g[m] = f[m] - (1 / m) * sum_{j >= 1} f[j] * (m - j) * g[m - j].
// Cost is O(deg * number of non-zero f[i], i >= 1).
// NOTE(review): source line 460 (the signature) is missing from this listing.
459template <typename mint>
461 int deg = -1) {
462 assert(f.empty() == false && f[0] == 1);
463 if (deg == -1) deg = f.size();
464 vector<pair<int, mint>> fs;
465 for (int i = 1; i < (int)f.size(); i++) {
466 if (f[i] != 0) fs.emplace_back(i, f[i]);
467 }
468
// Cache of modular inverses invs[i] = 1 / i, grown on demand and kept across
// calls (function-local static). This assumes mint's modulus never changes
// between calls, and that it fits in int.
469 int mod = mint::get_mod();
470 static vector<mint> invs{1, 1};
471 while ((int)invs.size() <= deg) {
472 int i = invs.size();
473 invs.push_back((-invs[mod % i]) * (mod / i));
474 }
475
476 FormalPowerSeries<mint> g(deg);
477 for (int k = 0; k < deg - 1; k++) {
478 for (auto& [j, fj] : fs) {
479 if (k < j) break;
480 int i = k - j;
481 g[k + 1] -= g[i + 1] * fj * (i + 1);
482 }
483 g[k + 1] *= invs[k + 1];
484 if (k + 1 < (int)f.size()) g[k + 1] += f[k + 1];
485 }
486 return g;
487}
488
// sparse_exp(f, deg): first deg terms of exp f for a sparse f with f[0] == 0.
// The recurrence comes from g' = f' * g:
//   (k + 1) * g[k + 1] = sum_{i} (i + 1) * f[i + 1] * g[k - i].
// Cost is O(deg * number of non-zero f[i], i >= 1).
// NOTE(review): source line 490 (the signature) is missing from this listing.
489template <typename mint>
491 int deg = -1) {
492 assert(f.empty() or f[0] == 0);
493 if (deg == -1) deg = f.size();
494 vector<pair<int, mint>> fs;
495 for (int i = 1; i < (int)f.size(); i++) {
496 if (f[i] != 0) fs.emplace_back(i, f[i]);
497 }
498
// Shared cache of inverses 1 / i (function-local static; assumes a fixed modulus).
499 int mod = mint::get_mod();
500 static vector<mint> invs{1, 1};
501 while ((int)invs.size() <= deg) {
502 int i = invs.size();
503 invs.push_back((-invs[mod % i]) * (mod / i));
504 }
505
506 FormalPowerSeries<mint> g(deg);
507 if (deg) g[0] = 1;
508 for (int k = 0; k < deg - 1; k++) {
// ip1 is the index i + 1 of a non-zero term f[i + 1].
509 for (auto& [ip1, fip1] : fs) {
510 int i = ip1 - 1;
511 if (k < i) break;
512 g[k + 1] += fip1 * g[k - i] * (i + 1);
513 }
514 g[k + 1] *= invs[k + 1];
515 }
516 return g;
517}
518
// sparse_pow(f, k, deg): first deg terms of f^k for a sparse f.
// Leading zeros x^zero are factored out and re-inserted as x^(zero * k).
// For f[0] != 0 the recurrence comes from f * g' = k * f' * g:
//   f[0] * a * g[a] = sum_{i >= 1} f[i] * g[a - i] * ((k + 1) * i - a).
// Cost is O(deg * number of non-zero f[i], i >= 1).
// NOTE(review): source line 520 (the signature) is missing from this listing.
519template <typename mint>
521 long long k, int deg = -1) {
522 if (deg == -1) deg = f.size();
523 if (k == 0) {
524 FormalPowerSeries<mint> g(deg);
525 if (deg) g[0] = 1;
526 return g;
527 }
// Count leading zero coefficients; the answer vanishes below x^deg once
// zero * k >= deg. The __int128_t product avoids overflow.
528 int zero = 0;
529 while (zero != (int)f.size() and f[zero] == 0) zero++;
530 if (zero == (int)f.size() or __int128_t(zero) * k >= deg) {
531 return FormalPowerSeries<mint>(deg, 0);
532 }
533 if (zero != 0) {
534 FormalPowerSeries<mint> suf{begin(f) + zero, end(f)};
535 auto g = sparse_pow(suf, k, deg - zero * k);
536 FormalPowerSeries<mint> h(zero * k, 0);
537 copy(begin(g), end(g), back_inserter(h));
538 return h;
539 }
540
// Shared cache of inverses 1 / i (function-local static; assumes a fixed modulus).
541 int mod = mint::get_mod();
542 static vector<mint> invs{1, 1};
543 while ((int)invs.size() <= deg) {
544 int i = invs.size();
545 invs.push_back((-invs[mod % i]) * (mod / i));
546 }
547
548 vector<pair<int, mint>> fs;
549 for (int i = 1; i < (int)f.size(); i++) {
550 if (f[i] != 0) fs.emplace_back(i, f[i]);
551 }
552
// g[0] uses the full exponent; k is reduced mod p only afterwards, because the
// recurrence only needs k as a field element.
553 FormalPowerSeries<mint> g(deg);
554 g[0] = f[0].pow(k);
555 mint denom = f[0].inverse();
556 k %= mint::get_mod();
557 for (int a = 1; a < deg; a++) {
558 for (auto& [i, f_i] : fs) {
559 if (a < i) break;
// (k + 1) * i - a can be negative. This relies on mint converting a
// negative long long correctly -- confirm in the mint type.
560 g[a] += f_i * g[a - i] * ((k + 1) * i - a);
561 }
562 g[a] *= denom * invs[a];
563 }
564 return g;
565}
566
// set_fft(): allocate the shared NTT<mint> engine on first use. The
// null-check-then-allocate is unsynchronized, and no delete of ntt_ptr appears
// in this listing (the engine lives for the whole program).
// NOTE(review): source line 568 (the signature) is missing from this listing.
567template <typename mint>
569 if (!ntt_ptr) ntt_ptr = new NTT<mint>;
570}
571
// FPS *= FPS: polynomial product through the shared engine's multiply()
// (schoolbook for short inputs, NTT otherwise). If either operand is empty,
// the result is empty; this also guards NTT::multiply against l == -1.
// NOTE(review): source line 573 (presumably the signature line) is missing
// from this listing.
572template <typename mint>
574 const FormalPowerSeries<mint>& r) {
575 if (this->empty() || r.empty()) {
576 this->clear();
577 return *this;
578 }
579 set_fft();
580 auto ret = static_cast<NTT<mint>*>(ntt_ptr)->multiply(*this, r);
581 return *this = FormalPowerSeries<mint>(ret.begin(), ret.end());
582}
583
584template <typename mint>
585void FormalPowerSeries<mint>::ntt() {
586 set_fft();
587 static_cast<NTT<mint>*>(ntt_ptr)->ntt(*this);
588}
589
590template <typename mint>
591void FormalPowerSeries<mint>::intt() {
592 set_fft();
593 static_cast<NTT<mint>*>(ntt_ptr)->intt(*this);
594}
595
// ntt_doubling(): forwards to NTT::ntt_doubling, which appends size() more
// transform values to this vector (which must already be in the transform
// domain).
// NOTE(review): source line 597 (the signature) is missing from this listing.
596template <typename mint>
598 set_fft();
599 static_cast<NTT<mint>*>(ntt_ptr)->ntt_doubling(*this);
600}
601
// ntt_pr(): the primitive root used by the shared NTT engine.
// NOTE(review): source line 603 (the signature) is missing from this listing.
602template <typename mint>
604 set_fft();
605 return static_cast<NTT<mint>*>(ntt_ptr)->pr;
606}
607
608template <typename mint>
609FormalPowerSeries<mint> FormalPowerSeries<mint>::inv(int deg) const {
610 assert((*this)[0] != mint(0));
611 if (deg == -1) deg = (int)this->size();
612 FormalPowerSeries<mint> res(deg);
613 res[0] = {mint(1) / (*this)[0]};
614 for (int d = 1; d < deg; d <<= 1) {
615 FormalPowerSeries<mint> f(2 * d), g(2 * d);
616 for (int j = 0; j < min((int)this->size(), 2 * d); j++) f[j] = (*this)[j];
617 for (int j = 0; j < d; j++) g[j] = res[j];
618 f.ntt();
619 g.ntt();
620 for (int j = 0; j < 2 * d; j++) f[j] *= g[j];
621 f.intt();
622 for (int j = 0; j < d; j++) f[j] = 0;
623 f.ntt();
624 for (int j = 0; j < 2 * d; j++) f[j] *= g[j];
625 f.intt();
626 for (int j = d; j < min(2 * d, deg); j++) res[j] = -f[j];
627 }
628 return res.pre(deg);
629}
630
// exp(f): first deg terms of exp of this series (deg = -1 means this->size()).
// Requires a zero constant term. This is a Newton iteration that doubles m
// each round:
//   - b holds exp(f) to m terms;
//   - c appears to hold 1/b to m/2 terms, extended by a Newton step each round;
//   - z1 / z2 cache NTTs of c of length m / 2m.
// NOTE(review): the c/z1/z2 invariants are inferred from the update pattern --
// confirm against the algorithm this was taken from.
631template <typename mint>
632FormalPowerSeries<mint> FormalPowerSeries<mint>::exp(int deg) const {
633 using fps = FormalPowerSeries<mint>;
634 assert((*this).size() == 0 || (*this)[0] == mint(0));
635 if (deg == -1) deg = this->size();
636
// inv[i] = 1 / i mod p, grown on demand by inplace_integral. This local
// shadows the member function inv().
637 fps inv;
638 inv.reserve(deg + 1);
639 inv.push_back(mint(0));
640 inv.push_back(mint(1));
641
// F <- integral of F: prepend a zero and divide coefficient i by i.
642 auto inplace_integral = [&](fps& F) -> void {
643 const int n = (int)F.size();
644 auto mod = mint::get_mod();
645 while ((int)inv.size() <= n) {
646 int i = inv.size();
647 inv.push_back((-inv[mod % i]) * (mod / i));
648 }
649 F.insert(begin(F), mint(0));
650 for (int i = 1; i <= n; i++) F[i] *= inv[i];
651 };
652
// F <- F' in place.
653 auto inplace_diff = [](fps& F) -> void {
654 if (F.empty()) return;
655 F.erase(begin(F));
656 mint coeff = 1, one = 1;
657 for (int i = 0; i < (int)F.size(); i++) {
658 F[i] *= coeff;
659 coeff += one;
660 }
661 };
662
// Initial state for m = 2: b = 1 + f[1] x (= exp(f) mod x^2), c = {1},
// z2 = {1, 1} (the length-2 transform of c).
663 fps b{1, 1 < (int)this->size() ? (*this)[1] : 0}, c{1}, z1, z2{1, 1};
664 for (int m = 2; m < deg; m *= 2) {
// y = length-2m transform of b.
665 auto y = b;
666 y.resize(2 * m);
667 y.ntt();
// Extend c from m/2 to m terms, using the first m transform values of y and
// the cached transform z1 of c.
668 z1 = z2;
669 fps z(m);
670 for (int i = 0; i < m; ++i) z[i] = y[i] * z1[i];
671 z.intt();
672 fill(begin(z), begin(z) + m / 2, mint(0));
673 z.ntt();
674 for (int i = 0; i < m; ++i) z[i] *= -z1[i];
675 z.intt();
676 c.insert(end(c), begin(z) + m / 2, end(z));
677 z2 = c;
678 z2.resize(2 * m);
679 z2.ntt();
// Correction for b, roughly: integrate c * (b * f' - b'), add f's terms
// [m, 2m), and multiply by b to obtain terms [m, 2m) of the new b.
680 fps x(begin(*this), begin(*this) + min<int>(this->size(), m));
681 x.resize(m);
682 inplace_diff(x);
683 x.push_back(mint(0));
684 x.ntt();
685 for (int i = 0; i < m; ++i) x[i] *= y[i];
686 x.intt();
687 x -= b.diff();
688 x.resize(2 * m);
689 for (int i = 0; i < m - 1; ++i) x[m + i] = x[i], x[i] = mint(0);
690 x.ntt();
691 for (int i = 0; i < 2 * m; ++i) x[i] *= z2[i];
692 x.intt();
693 x.pop_back();
694 inplace_integral(x);
695 for (int i = m; i < min<int>(this->size(), 2 * m); ++i) x[i] += (*this)[i];
696 fill(begin(x), begin(x) + m, mint(0));
697 x.ntt();
698 for (int i = 0; i < 2 * m; ++i) x[i] *= y[i];
699 x.intt();
// Append terms [m, 2m) to b.
700 b.insert(end(b), begin(x) + m, end(x));
701 }
702 return fps{begin(b), begin(b) + deg};
703}
704
705
706
707template <typename mint>
708FormalPowerSeries<mint> sqrt(const FormalPowerSeries<mint> &f, int deg = -1) {
709 if (deg == -1) deg = (int)f.size();
710 if ((int)f.size() == 0) return FormalPowerSeries<mint>(deg, 0);
711 if (f[0] == mint(0)) {
712 for (int i = 1; i < (int)f.size(); i++) {
713 if (f[i] != mint(0)) {
714 if (i & 1) return {};
715 if (deg - i / 2 <= 0) break;
716 auto ret = sqrt(f >> i, deg - i / 2);
717 if (ret.empty()) return {};
718 ret = ret << (i / 2);
719 if ((int)ret.size() < deg) ret.resize(deg, mint(0));
720 return ret;
721 }
722 }
723 return FormalPowerSeries<mint>(deg, 0);
724 }
725
726 int64_t sqr = mod_sqrt(f[0].get(), mint::get_mod());
727 if (sqr == -1) return {};
728 assert(sqr * sqr % mint::get_mod() == f[0].get());
729 FormalPowerSeries<mint> ret = {mint(sqr)};
730 mint inv2 = mint(2).inverse();
731 for (int i = 1; i < deg; i <<= 1) {
732 ret = (ret + f.pre(i << 1) * ret.inv(i << 1)) * inv2;
733 }
734 return ret.pre(deg);
735}
736
// circular(fre, fim, deg): let F = fre + i * fim, where i^2 = -1 is emulated
// by carrying real and imaginary parts as separate series. Returns
// (re, im), the first deg terms of H = re + i * im = exp(i * F). With
// fim == 0 this is (cos(fre), sin(fre)).
// It uses the Newton iteration H <- H * (1 - log H + i * F), doubling
// precision each round. log H is the integral of
// (re' * re + im' * im, im' * re - re' * im) / (re^2 + im^2).
// Both arguments must have a zero constant term.
// NOTE(review): source line 738 (the signature; per the symbol index it
// returns pair<FormalPowerSeries<mint>, FormalPowerSeries<mint>>) is missing
// from this listing.
737template <typename mint>
739 const FormalPowerSeries<mint> &fre, const FormalPowerSeries<mint> &fim,
740 int deg = -1) {
741 using fps = FormalPowerSeries<mint>;
742 assert(fre.size() == 0 || fre[0] == mint(0));
743 assert(fim.size() == 0 || fim[0] == mint(0));
744 if (deg == -1) deg = (int)max(fre.size(), fim.size());
745 fps re({mint(1)}), im({mint(0)});
746
747 fps::set_fft();
// NOTE(review): set_fft() as listed always allocates the engine, so ntt_ptr
// is non-null here and this plain-multiplication branch looks unreachable.
// It presumably mirrors a set_fft variant that leaves ntt_ptr null for
// non-NTT-friendly moduli.
748 if (fps::ntt_ptr == nullptr) {
749 for (int i = 1; i < deg; i <<= 1) {
750 fps dre = re.diff();
751 fps dim = im.diff();
752 fps fhypot = (re * re + im * im).inv(i << 1);
753 fps ere = dre * re + dim * im;
754 fps eim = dim * re - dre * im;
755 fps logre = (ere * fhypot).pre((i << 1) - 1).integral();
756 fps logim = (eim * fhypot).pre((i << 1) - 1).integral();
757 fps gre = (-logre) + mint(1) - fim.pre(i << 1);
758 fps gim = (-logim) + fre.pre(i << 1);
759 fps hre = (re * gre - im * gim).pre(i << 1);
760 fps him = (re * gim + im * gre).pre(i << 1);
761 swap(re, hre);
762 swap(im, him);
763 }
764 } else {
// Same iteration with products formed in the transform domain.
// ntt_doubling extends size-2i transforms to size 4i instead of
// re-transforming from scratch.
765 for (int i = 1; i < deg; i <<= 1) {
766 fps dre = re.diff();
767 fps dim = im.diff();
768 re.resize(i << 1);
769 im.resize(i << 1);
770 dre.resize(i << 1);
771 dim.resize(i << 1);
772 re.ntt();
773 im.ntt();
774 dre.ntt();
775 dim.ntt();
776 fps fhypot(i << 1), ere(i << 1), eim(i << 1);
777 for (int j = 0; j < 2 * i; j++) {
778 fhypot[j] = re[j] * re[j] + im[j] * im[j];
779 ere[j] = dre[j] * re[j] + dim[j] * im[j];
780 eim[j] = dim[j] * re[j] - dre[j] * im[j];
781 }
782 fhypot.intt();
783 fhypot = fhypot.inv(i << 1);
784 fhypot.resize(i << 2);
785 fhypot.ntt();
786 ere.ntt_doubling();
787 eim.ntt_doubling();
788 fps logre(i << 2), logim(i << 2);
789 for (int j = 0; j < 4 * i; j++) {
790 logre[j] = ere[j] * fhypot[j];
791 logim[j] = eim[j] * fhypot[j];
792 }
793 logre.intt();
794 logim.intt();
795 logre = logre.pre((i << 1) - 1).integral();
796 logim = logim.pre((i << 1) - 1).integral();
// g = 1 - log H + i * F, split into real / imaginary parts.
797 fps gre = (-logre) + mint(1) - fim.pre(i << 1);
798 fps gim = (-logim) + fre.pre(i << 1);
799 gre.resize(i << 2);
800 gim.resize(i << 2);
801 gre.ntt();
802 gim.ntt();
803 re.ntt_doubling();
804 im.ntt_doubling();
// H * g as a complex product.
805 fps hre(i << 2), him(i << 2);
806 for (int j = 0; j < 4 * i; j++) {
807 hre[j] = re[j] * gre[j] - im[j] * gim[j];
808 him[j] = re[j] * gim[j] + im[j] * gre[j];
809 }
810 hre.intt();
811 him.intt();
812 hre = hre.pre(i << 1);
813 him = him.pre(i << 1);
814 swap(re, hre);
815 swap(im, him);
816 }
817 }
818 return make_pair(re.pre(deg), im.pre(deg));
819}
820
821// calculate F(x + a)
// This is the Taylor shift of f by a, done with one polynomial multiplication:
//   - scale f[i] by i! and reverse;
//   - multiply by g[i] = a^i / i! and truncate to N terms;
//   - reverse back and divide coefficient i by i!.
// f is taken by value and reused as the result.
// NOTE(review): source line 823 -- the signature, listed in the symbol index as
// TaylorShift(FormalPowerSeries<mint> f, mint a, Binomial<mint>& C) -- is
// missing from this listing. C.fac(i), C.inv(i) and C.finv(i) are assumed to be
// i!, 1/i and 1/i! -- confirm in Binomial.
822template <typename mint>
824 Binomial<mint>& C) {
825 using fps = FormalPowerSeries<mint>;
826 int N = f.size();
827 for (int i = 0; i < N; i++) f[i] *= C.fac(i);
828 reverse(begin(f), end(f));
829 fps g(N, mint(1));
830 for (int i = 1; i < N; i++) g[i] = g[i - 1] * a * C.inv(i);
831 f = (f * g).pre(N);
832 reverse(begin(f), end(f));
833 for (int i = 0; i < N; i++) f[i] *= C.finv(i);
834 return f;
835}
FormalPowerSeries< mint > sparse_exp(const FormalPowerSeries< mint > &f, int deg=-1)
Definition fps2.hpp:490
FormalPowerSeries< mint > sparse_div(const FormalPowerSeries< mint > &f, const FormalPowerSeries< mint > &g, int deg=-1)
g が sparse であることを仮定し、f * g.inv() を計算する
Definition fps2.hpp:416
FormalPowerSeries< mint > sqrt(const FormalPowerSeries< mint > &f, int deg=-1)
Definition fps2.hpp:708
FormalPowerSeries< mint > sparse_log(const FormalPowerSeries< mint > &f, int deg=-1)
Definition fps2.hpp:460
pair< FormalPowerSeries< mint >, FormalPowerSeries< mint > > circular(const FormalPowerSeries< mint > &fre, const FormalPowerSeries< mint > &fim, int deg=-1)
Definition fps2.hpp:738
FormalPowerSeries< mint > TaylorShift(FormalPowerSeries< mint > f, mint a, Binomial< mint > &C)
Definition fps2.hpp:823
FormalPowerSeries< mint > sparse_inv(const FormalPowerSeries< mint > &f, int deg=-1)
Definition fps2.hpp:438
FormalPowerSeries< mint > sparse_pow(const FormalPowerSeries< mint > &f, long long k, int deg=-1)
Definition fps2.hpp:520
FPS & operator+=(const mint &r)
Definition fps2.hpp:235
FPS operator+(const mint &v) const
Definition fps2.hpp:290
FPS dot(FPS r) const
Definition fps2.hpp:313
FPS & operator-=(const mint &r)
Definition fps2.hpp:247
FPS operator-(const mint &v) const
Definition fps2.hpp:292
FPS operator>>(int sz) const
Definition fps2.hpp:326
FPS exp(int deg=-1) const
Definition fps2.hpp:632
FPS pow(int64_t k, int deg=-1) const
Definition fps2.hpp:373
FPS & operator+=(const FPS &r)
Definition fps2.hpp:229
FPS operator+(const FPS &r) const
Definition fps2.hpp:289
FPS inv(int deg=-1) const
Definition fps2.hpp:609
FPS diff() const
Definition fps2.hpp:339
FPS operator%(const FPS &r) const
Definition fps2.hpp:296
FPS & operator-=(const FPS &r)
Definition fps2.hpp:241
FPS operator/(const FPS &r) const
Definition fps2.hpp:295
FPS & operator/=(const FPS &r)
Definition fps2.hpp:258
static int ntt_pr()
Definition fps2.hpp:603
FPS integral() const
Definition fps2.hpp:350
static void set_fft()
Definition fps2.hpp:568
FPS pre(int sz) const
Definition fps2.hpp:320
FPS & operator%=(const FPS &r)
Definition fps2.hpp:283
FPS log(int deg=-1) const
Definition fps2.hpp:367
static void * ntt_ptr
Definition fps2.hpp:395
mint eval(mint x) const
Definition fps2.hpp:361
void ntt_doubling()
Definition fps2.hpp:597
FPS operator*(const FPS &r) const
Definition fps2.hpp:293
FPS & operator*=(const FPS &r)
FPS rev() const
Definition fps2.hpp:307
FPS operator-() const
Definition fps2.hpp:297
FPS operator-(const FPS &r) const
Definition fps2.hpp:291
FPS operator*(const mint &v) const
Definition fps2.hpp:294
FPS & operator*=(const mint &v)
Definition fps2.hpp:253
NTT Friendly 素数用 NTT 構造体
Definition fps2.hpp:2
void ntt(vector< mint > &a)
Definition fps2.hpp:172
static constexpr uint32_t mod
Definition fps2.hpp:38
void intt(vector< mint > &a)
Definition fps2.hpp:177
mint dy[level]
Definition fps2.hpp:41
void fft4(vector< mint > &a, int k)
Definition fps2.hpp:58
NTT()
Definition fps2.hpp:56
void ifft4(vector< mint > &a, int k)
Definition fps2.hpp:115
static constexpr uint32_t pr
Definition fps2.hpp:39
mint dw[level]
Definition fps2.hpp:41
vector< mint > multiply(const vector< mint > &a, const vector< mint > &b)
Definition fps2.hpp:184
void ntt_doubling(vector< mint > &a)
Definition fps2.hpp:213
static constexpr int level
Definition fps2.hpp:40
void setwy(int k)
Definition fps2.hpp:43
static constexpr uint32_t get_pr()
Definition fps2.hpp:3