题目链接:acwing
首先,我的第一反应就是将$a^0$,$a^1$,$a^2$,…,$a^n$全部加起来就是答案。
代码很短:

1
2
3
4
5
6
7
8
9
10
11
12
13
#include<bits/stdc++.h>
using namespace std;
int a,b;
long long ans=1,f=1;
int main() {
scanf("%d%d",&a,&b);
for(int i=1;i<=b;i++) {
f=f*a%9901;
ans=(ans+f)%9901;
}
printf("%lld",ans);
return 0;
}

然而,这段代码并不能通过这个程序。我想了半天才发现,在a是合数的情况下会产生错误。
正确代码(大概采用一个分治的想法):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include<bits/stdc++.h>
#define ll long long
using namespace std;
long long a,b;
long long ans=1,f=1;
map<ll,ll> y;
map<ll,ll>::iterator it;
void getprime() {
for(int i=2;i<=a;i++)
while(a%i==0) {
a/=i;
y[i]++;
}
}
long long quick_pow(long long n,long long p) {
if(p==0) return 1;
if(p==1) return n;
long long ans=0;
ans=quick_pow(n,p/2);
ans=(ans*ans)%9901;
if(p&1) ans=ans*n%9901;
return ans%9901;
}
int sum(long long p,long long c) {
if(c==0) return 1;
if(c==1) return 1+p;
if(c&1) return (sum(p,(c-1)/2)*(1+quick_pow(p,(c+1)/2)))%9901; //如果c是奇数
else return (quick_pow(p,c)+(1+quick_pow(p,c/2))*sum(p,c/2-1))%9901;
}
int main() {
scanf("%lld%lld",&a,&b);
if(a==0) {
printf("0");
exit(0);
}
getprime();
for(it=y.begin();it!=y.end();it++)
ans=(ans*sum(it->first,it->second*b))%9901;
printf("%lld",ans);
return 0;
}