关于常系数线性递推...

----------------------------------------------------------------------------------
最近研究了一下这个问题...
最后实现的时候花了整整两天(当然其中大部分时间都在浪...)
主要还是不会的太多了..=_=感觉学到了好多新知识...
然后在学习关于多项式求逆的那一套理论的时候,发现Picks博客上就有这个问题的解法而且貌似比我的简单得多..=_=
不过我和他的方法貌似不一样(窝并看不懂他在说什么...)
所以也算是自己的一个发现吧
另外期间大部分我不能解决的问题都是问的Skydec...
----------------------------------------------------------------------------------
首先是问题:
已知$f(n)$的递推式$f(n)=\sum_{i=0}^{k-1}a_{i}*f(n-k+i)$(给定$k$和向量$a$)
并且给出$f(0)$~$f(k-1)$
现在给定一个$n$,求$f(n)$
----------------------------------------------------------------------------------
$O(lg(n)*k^3))$做法:
可以使用矩阵乘法
首先可以构造出初始向量$a=(f(0),f(1)...f(k-1))^T$
再构造出转移矩阵$A$:
$\begin{pmatrix}0 & 1 & 0 & \cdots & 0 \\0 & 0 & 1 & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\0 & 0 & 0 & \cdots & 1 \\a_0 & a_1 & a_2 & \cdots & a_{k-1} \\ \end{pmatrix}$
然后$A^{n}*a$的第$1$行第$1$列就是答案
然后就可以快速幂辣
----------------------------------------------------------------------------------
$O(lg(n)*k^2))$做法:
首先得到了$f(n)=\sum_{i=0}^{k-1}a_{i}*f(n-2^{0}*k+i)$ ...式$1$
把式$1$代入式$1$:
$f(n)=\sum_{i=0}^{k-1}a_{i}*\sum_{j=0}^{k-1}a_{j}f(n-2^{1}*k+i+j)$ 
可以暴力求,最后得到的式子共有$2*k$项,即:
$f(n)=\sum_{i=0}^{2*k-1}a^{'}_{i}f(n-2^{1}*k+i)$...式$1.5$ 
然后把式$1$代入式$1.5$的$f(n-2^{1}*k+k)$~$f(n-2^{1}*k+2*k-1)$这几项中,就得到了
$f(n)=\sum_{i=0}^{k-1}b_{i}*f(n-2^{1}*k+i)$ ...式$2$
这样不断地把自己代入自己,然后把式$1$代入自己的后$k$项,就可以的到式$p$:(我把$f(n-2^{p+1}*k+i)$叫做第$i$项)
$f(n)=\sum_{i=0}^{k-1}b_{i}*f(n-2^{p}*k+i)$ ...式$p$
这样就是一个倍增的过程,每次可以把$k$的系数乘以$2$
然后$n$是奇数的时候,需要往后推一项,把式$1$带进去即可
----------------------------------------------------------------------------------
$O(lg(n)*lg(k)*k))$做法:
重头戏来辣=w=
考虑上面那个方法的优化
主要分为两步:自己代入自己和把后$k$项展开到前$k$项
首先看自己代入自己这个过程:
$f(n)=\sum_{i=0}^{k-1}a_{i}*\sum_{j=0}^{k-1}a_{j}f(n-2^{p+1}*k+i+j)$ 
聪明的读者已经发现,这就是个卷积,所以做一个多项式乘法就好了...
 
现在假设已经求到了$f(n)=\sum_{i=0}^{k-1}c_{i}*f(n-2^{1}*k+i)$
另外为了方便表示,不放把$a$翻转一下然后移动一下位置(设这样得到的向量是$a^{'}$):
$f(n)=\sum_{i=1}^{k}a^{'}_{i}*f(n-i)$
然后来考虑第二步,这一步中,显然$k$~$2*k-1$这些项会对$0$~$k-1$这些项产生贡献
聪明的读者不难发现,第s项($k \le s \lt 2*k$)对第t项($0 \le t \lt k$)产生的贡献是:
$g(s,t)=c_{s}*\sum_{i=1}^{k}\sum_{b_{1},b_{2}...b_{i}且\sum_{j=1}^{i} b_{i}=s-t且 b_{j} \ge 1 且b_{i} \ge k-t}\prod_{i=1}^{k}a^{'}_{b_{i}}$
而第$t$项最后得到的贡献就是$d_t=\sum_{k \le s \lt 2*k}g(s,t)$
听起来很麻烦?
先假设不考虑$b_{i}项$
不放把$t$加上$b_{i}$
因此有:
$c^{'}_{t}=c_{s}*\sum_{i=1}^{k}\sum_{b_{1},b_{2}...b_{i}且\sum_{j=1}^{i} b_{i}=s-t且 b_{j} \ge 1}\prod_{i=1}^{k}a^{'}_{b_{i}}$
其中$k \le t \lt 2*k$
这就简单多了
假设$A$是$a^{'}$的生成函数
设$A[k]$表示$A$第$k$次项系数
那么不难发现上式就等于
$c^{'}_{t}=c_{s}*(\sum_{i=0}^{k}A^i)[s-t]$
=$c_{s}*((\sum_{i=0}^{+\infty}A^i)$ $mod$ $x^{k})[s-t]$
=$c_{s}*(1/(1-A))$ $mod$ $x^{k})[s-t]$
这个显然是一个卷积,于是一个多项式求逆一个多项式乘法就可以求到了
现在求到了$c^{'}$,考虑将他的各项减去$b_{i} (b_{i} \ge k-t)$来得到最后$d$
因此:
$d_t=\sum_{k \le s \lt 2*k}c^{'}_{s}*a^{'}_{s-t}$
然后这也显然是个卷积,多项式乘法即可
至此问题就解决了,最后$c^{'}_{i}=c_{i}+d_{i}(0 \le i \lt k)$就是展开完之后的$f(n)$的表达式的前$k$项
只用到了多项式乘法和多项式求逆,这两个都是$O(lg(k)*k)$的
然后算上倍增的时间,最后的时间复杂度$O(lg(n)*lg(k)*k))$
 
最后,实现的时候,多项式求逆中会爆精度,所以只能用NTT做(或者用分块乘法...)...
----------------------------------------------------------------------------------
这里有一个很挫的实现:(只实现了最后一步,也就是把$f(n)$含有$2*k$项的表达式展开成只含有$k$项的表达式,这是最难也是最关键的一步)
(输入输出格式是:第一行输入$k$,然后输入$a^{'}$,然后输入有$2*n$项的$c$,输出展开之后只有$n$项的$c^{'}$)

#include <stdio.h>

#include <string.h>

#include <math.h>

#include <time.h>



typedef long long Long;

typedef double Ld;



const int N = 800000*2;

const Ld PI = acos(-1);

const int P = 998244353;

const int W = 3;



int pow(int a,int b)

{

	a %= P;

	if(a < 0)

	a += P; 

	int r = 1.;

	while(b)

	{

		if(b & 1)

		r = (1LL * r * a) % (1LL * P);

		a = (1LL * a * a) % (1LL * P);

		b >>= 1;

	}

	return r;

}



namespace NTT

{

	int wn[N + 10];

	int rev[N + 10];

	void init(int n)

	{

		int log2n = 0;

		int nn = (n >> 1);

		while(nn)

		{

			log2n ++;

			nn >>= 1;

		}



		int num = 0;

		for(int i = 1;i <= n;i <<= 1)

		{

			wn[num++] = pow(W,(P-1)/i);

		}



		int x = 0;

		int y = 0;

		for(int i = 0;i < n;i++)

		{

			x = i;

			y = 0;

			for(int j = 1;j <= log2n;j++)

			{

				y <<= 1;

				y |= (x & 1);

				x >>= 1;

			}

			rev[i] = y;

		}

	}



	int buf[N + 10];

	void NTT(int * a,int n,int s)

	{

		for(int i = 0;i < n;i++)

		buf[i] = a[rev[i]];

		for(int i = 0;i < n;i++)

		a[i] = buf[i];



		int t = 2;

		int div2 = 0;

		int l = 0;

		int num = 0;

		while(t <= n)

		{

			div2 = (t >> 1);

			l = n / t;

			int ww = wn[++num];

			if(s)

			ww = pow(ww,P-2);

			for(int i = 0;i < n;i += t)

			{

				int w = 1;

				for(int j = 0;j < div2;j++)

				{

					int x = a[i+j];

					int y = (1LL * a[i+j+div2] * w) % (1LL * P);

					a[i+j] = (x + y) % P;

					a[i+j+div2] = (x + P - y) % P;

					w = (1LL * w * ww) % (1LL * P);

				}

			}



			t <<= 1;

		}

		if(s)

		{

			int di = pow(n,P-2);

			for(int i = 0;i < n;i++)

			a[i] = (1LL * a[i] * di) % (1LL * P);

		}

	}



	void mul(int * a,int * b,int * c,int n)

	{

		init(2 * n);

		NTT(a,2 * n,0);

		NTT(b,2 * n,0);

		for(int i = 0;i < 2 * n;i++)

		c[i] = (1LL * a[i] * b[i]) % (1LL * P);

		NTT(c,2 * n,1);

	}



	int tmp[N + 10];

	void get_inv(int * a,int * b,int t)//b = a^-1 mod x^t

	{

		if(t == 1)

		{

			b[0] = pow(a[0],P-2);

			return ;

		}



		get_inv(a,b,(t + 1) >> 1);



		int k = 0;

		for(k = 1;k <= (t << 1) + 3;k <<= 1);



		for(int i = 0;i < t;i++)

		tmp[i] = a[i];

		for(int i = t;i < k;i++)

		tmp[i] = 0;



		init(k);



		NTT(tmp,k,0);

		NTT(b,k,0);



		for(int i = 0;i < k;i++)

		{

			int val = (1LL * b[i] * b[i]) % (1LL * P);

			tmp[i] = (1LL * tmp[i] * val) % (1LL * P);

		}

		NTT(tmp,k,1);

		NTT(b,k,1);



		for(int i = 0;i < k;i++)

		b[i] = (2LL * b[i] + P - tmp[i]) % P; 

		for(int i = t;i < k;i++)

		b[i] = 0;

	}

};



int pa[N + 10]; 

int a[N + 10];

int d[N + 10];//d[i] = c[i+n]

int c[N + 10];

int b[N + 10];

int p[N + 10];



int main()

{

	FILE * fin = fopen("test.in","r");

	FILE * fout = fopen("test.out","w");



	int n = 0;

	fscanf(fin,"%d",&n);



	for(int i = 1;i <= n;i++)

	{

		fscanf(fin,"%d",a + i);

		a[i] = ((a[i] % P) + P) % P;

	}



	for(int i = 0;i < 2 * n;i++)

	{

		fscanf(fin,"%d",c + i);

		c[i] = ((c[i] % P) + P) % P;

		if(i >= n)

		d[i] = c[i];

	}



	{

		for(int i = 1;i <= n;i++)

		pa[i] = -a[i];

		pa[0] ++;

		NTT::get_inv(pa,b,n);

	}



	{

		for(int i = 0;i <= (n>>1);i++)

		{

			if(i < n-i)

			b[i] ^= b[n-i] ^= b[i] ^= b[n-i];

		}



		int e = 3 * n;

		int q = 1;

		while(q <= e)

		q <<= 1;



		NTT::mul(d,b,p,q);



		for(int i = 3 * n - 1;i >= 2 * n;i--)

		{

			p[i-n] = p[i];

			p[i] = 0;

		}

	}



	{

		memset(b,0,sizeof(b));



		for(int i = 0;i <= n;i++)

		{

			if(i < 2*n-i)

			p[i] ^= p[2*n-i] ^= p[i] ^= p[2*n-i];

		}



		int e = 6 * n;

		int q = 1;

		while(q <= e)

		q <<= 1;



		NTT::mul(a,p,b,q);



		for(int i = 0;i < 2 * n;i++)

		c[i] += b[2*n-i];

	}



	for(int i = 0;i < n;i++)

	{

		if(i == n-1)

		fprintf(fout,"%d\n",((c[i] % P) + P) % P);

		else fprintf(fout,"%d ",((c[i] % P) + P) % P);

	}



	fclose(fin);

	fclose(fout);



	return 0;

}
----------------------------------------------------------------------------------