不定期プログラミング覚え書き

青コーダーと黄コーダーの間を彷徨う社会人プロコン勢が余力のあるときに復習した内容をまとめるブログ

yukicoder #364: 門松木

問題

No.364 門松木 - yukicoder
すべての隣り合う3ノードが門松列になっている木を門松木と呼ぶ。
入力に木が与えられるので、その部分木である門松木のがもつ門松列の個数の最大値を答えよ

考えたこと

どう見ても木DPです……と思ってからDPを実装可能な状況にするまでに凄い時間かかった。
最終的に、
dp[n][o][f]: ノードnが両端のノードより大きい(o=0)/小さい(o=1)ような、門松木の最大の門松列の数。
但し、f=0の時はparent[n]を考慮せず、nとその子孫のみで門松木を作った場合で、f=1の時にはparent[n]を必ず利用する場合の値(なおかつ、parent[n]を含む門松列もカウントする)を格納するとする。
というdpテーブルに落ち着いた。

fというフラグが入っているのは、dp[n][o][f]の計算にあたって、nの孫の個数が依存するため、
その部分まで吸収した値をdp[n][1-o][1]として持たせよう、という算段である。
(ただそんなけったいなことをしたがゆえにだいぶバグって大変でもあった)

コード
int A[100000];
vector<int> edge[100000];
int parent[100000];
int dp[100000][2][2]; //dp[i][j]... node i as a parent, j=1 must use i's parent
void dfs(int p){
//   //cout <<p<<endl;
   for( int i = 0 ;i<(int)edge[p].size();i++){
      int nxt = edge[p][i];
      if( nxt!= parent[p] ){
         parent[nxt]=p;
         dfs(nxt);
      }
   }
}
int ans = 0;
int calc( int p ,int o, int f){
   map<int,int> m;
   //cout << p <<", "<< o <<", "<< f<<endl;
   if(dp[p][o][f]!=-1 ) return dp[p][o][f];
   for( int ei=0; ei < (int)edge[p].size(); ei++ ){
      int n = edge[p][ei];
      //cout << p <<"- "<< n << endl;
      if( n==parent[p]) continue;
      calc(n,0,0);
      calc(n,1,0);
      if( A[n]>A[p] &&o==0){
         m[A[n]]=max(m[A[n]],calc(n,1-o,1));
      }
      if( A[n]<A[p]&&o==1){
         m[A[n]]=max(m[A[n]],calc(n,1-o,1));
      }
   }

   map<int,int>::iterator it=m.begin();
   int cnt=0;
   int cnt_wo_parent=0;
   dp[p][o][0]=0;
   dp[p][o][1]=0;
   while(it!= m.end()){
      cnt++;
      dp[p][o][0]+=it->second;
      if( it->first!= A[parent[p]] ){
         cnt_wo_parent++;
         dp[p][o][1]+=it->second;
      }
      it++;
   }
   dp[p][o][0]+=cnt*(cnt-1)/2;
   dp[p][o][1]+=cnt_wo_parent*(cnt_wo_parent-1)/2;
   if( o==0 && A[parent[p]]>A[p]) dp[p][o][1]+=cnt_wo_parent;
   if( o==1 && A[parent[p]]<A[p]) dp[p][o][1]+=cnt_wo_parent;
   ans=max(dp[p][o][0],ans);
   if( parent[p]!=p ) ans=max(dp[p][o][1],ans);
   //cout << p << ", "<< o <<", "<< 0 <<": "<< dp[p][o][0]<<endl;
   //cout << p << ", "<< o <<", "<< 1 <<": "<< dp[p][o][1]<<endl;
   return dp[p][o][f];
}

int main(){
  int N;
  cin >> N;
  for( int i = 0 ; i <N; i++ ){
      cin >> A[i];
  }
  for( int i = 0;  i < N-1 ; i++ ){
      int x,y;
      cin >> x >> y;
      x--;y--;
      edge[x].push_back(y);
      edge[y].push_back(x);
  }
  memset(parent,-1,sizeof(parent));
  if( N <= 2 ){
   cout<<0<<endl;
   return 0;
  }
  int p=0;
  while( edge[p].size() < 2 ) p++;
  parent[p]=p;
  dfs(p);
  memset(dp,-1,sizeof(dp)); 
  calc(p,0,0);
  calc(p,1,0);
  cout << ans <<endl;
  return 0;
}