/// Solution to "Trapping Rain Water": given unit-width bar heights, compute
/// how much water is held between the bars.
class Solution {
public:
    /// @param h  bar heights, one per unit-width column (assumed non-negative).
    /// @return   total units of trapped water. The sum is accumulated in
    ///           `long long`; it is narrowed to `int` only at the end, so the
    ///           result must fit in `int` (NOTE(review): true for the usual
    ///           problem limits; confirm for other callers).
    ///
    /// One left-to-right pass over a stack of segments. Each segment is a run of
    /// `width` columns whose surface (bar or water) sits at `level`. Levels are
    /// non-increasing from bottom to top, because every segment lower than the
    /// incoming bar is popped before that bar is pushed. The bottom segment is
    /// therefore always a single real bar of width 1.
    int trap(std::vector<int>& h) {
        struct Segment {
            long long width;
            int level;
        };

        std::vector<Segment> stack;
        long long total = 0;

        for (const int height : h) {
            long long width = 0;  // combined width of the popped segments
            long long area = 0;   // sum of width * level over the popped segments
            int highest = 0;      // level of the last (i.e. tallest) popped segment

            while (!stack.empty() && stack.back().level < height) {
                const Segment top = stack.back();
                stack.pop_back();
                highest = top.level;
                width += top.width;
                area += top.width * top.level;
            }

            if (stack.empty()) {
                // Nothing to the left is as tall as this bar. The popped region is
                // bounded on the left by the tallest popped segment (the old bottom
                // bar), so fill only up to `highest`. Everything to the left is now
                // shadowed by this bar and can be forgotten.
                total += width * highest - area;
                stack.push_back({1, height});
            } else {
                // A wall of at least `height` remains on the left. Fill the popped
                // region up to `height` and merge it with this column into one
                // segment at that level.
                total += width * height - area;
                stack.push_back({width + 1, height});
            }
        }

        return static_cast<int>(total);
    }
};