0
votes

Is there any way to use a Segment Tree structure to compute the frequency of a given value in an array?

Suppose there is an array A of size N, and every element A[i] of the array contains the value 0, 1 or 2. I want to perform the following operations:

  • Compute the amount of zeroes in any range [a,b] of the array
  • Increment (mod 3) every element in any range [a,b] of the array

Example: If A = [0,1,0,2,0]:

  • Query[2,4] must return 1 , since there is one 0 in the range [2,4]
  • Increment[2,4] updates A to [0,2,1,0,0]

This looks really similar to the Range Sum Query problem, which can be solved using Segment Trees (in this case using Lazy Propagation because of range updates), but I had no success adapting my segment tree code to this problem: if I store the values in the tree like in a normal RSQ, a parent node containing the value "3" (for example) wouldn't mean anything, since from this information alone I can't extract how many zeroes are present in its range.

Thanks in advance!

--

EDIT:

Segment Trees are binary tree structures that store intervals related to an array in their nodes. The leaf nodes store the actual array cells, and each parent node stores a function f(node->left, node->right) of its children. Segment Trees are commonly used to perform Range Sum Queries, in which we want to compute the sum of all elements in a range [a,b] of the array. In this case, the function computed by the parent nodes is the sum of the values in its children nodes. We want to use segment trees to solve the Range Sum Query problem because they let us answer each query in O(log n) (we only need to descend the tree until we find the nodes that are completely covered by our range query), much better than the naive O(n) algorithm.

2
You might want to clue us in on what a segment tree is. – Tim Biegeleisen
Edited the OP with additional info! – Lucas Sampaio

2 Answers

1
votes

Since actual array values are stored in the leaves (level L), let the nodes at level L - 1 store how many zeros they contain (which will be a value in the range [0, 2]). Other than that, everything is the same, the rest of the nodes will compute f(node->left, node->right) as node->left + node->right and the count of zeros will be propagated to the root.

After incrementing a range, if that range contained no zeros then nothing needs to be done. If however that range had zeros, then all those zeros will now be ones and the function value of the current node (call it F) now becomes just zero. That change in the value now needs to be propagated upwards to the root, each time subtracting F from the function values.

0
votes

This question can be easily solved using square root decomposition. First create the prefix-sum array, reducing each prefix sum modulo 3. Divide the whole array into sqrt(n) blocks. Each block keeps counts of the number of 0's, 1's and 2's it contains. Also create one temporary array which holds the pending sum to be added to the elements of each block. Here is the implementation in C++:

#include <bits/stdc++.h>
using namespace std;
// Shorthand macros. They are plain textual substitutions, so every later
// token spelled exactly like a macro name (including the one-letter names
// `w`, `f` and `s`) is rewritten by the preprocessor.

// scanf wrappers: read one value into `a` (int, long long, long).
#define si(a) scanf("%d",&a)
#define sll(a) scanf("%lld",&a)
#define sl(a) scanf("%ld",&a)
// printf wrappers: print `a` followed by a newline (int, long, long long).
#define pi(a) printf("%d\n",a)
#define pl(a) printf("%ld\n",a)
#define pll(a) printf("%lld\n",a) 
// Single-character read / write (pc prints no trailing newline).
#define sc(a) scanf("%c",&a)
#define pc(a) printf("%c",a)
// Type alias and numeric constant.
#define ll long long
#define mod 1000000007
// Keyword / member-name abbreviations.
#define w while
#define pb push_back
#define mp make_pair
#define f first
#define s second
#define INF INT_MAX
// Inclusive counting loop: i runs from a to b, both ends included.
#define fr(i,a,b) for(int i=a;i<=b;i++)



///////////////////////////////////////////////////////////////
// Tally kept for one sqrt-decomposition block: how many of its elements
// currently have residue 1, 2 and 0 respectively.
struct block
{
    int one;
    int two;
    int zero;

    // A freshly created block has counted nothing yet.
    block() : one(0), two(0), zero(0) {}
};
// Working storage used by main() (zero-initialised, static storage duration).
// a   : stored prefix residues; main() adds the owning block's entry of `sum`
//       (mod 3) whenever it reads one.
// a1  : the current value (mod 3) of each individual element.
// sum : one pending offset per block, not yet folded into `a`.
long long a[100005];
long long a1[100005];
long long sum[400];
int main()
{
    int n,m;
    cin>>n>>m;
    string s;
    cin>>s;
    int N=(int)(sqrt(n));
    struct block b[N+10];
    for(int i=0;i<n;i++)
    {
        a[i]=s[i]-'0';
        a[i]%=3;
        a1[i]=a[i];
    }
    for(int i=1;i<n;i++)
    a[i]=(a[i]+a[i-1])%3;
    for(int i=0;i<n;i++)
    {
        if(a[i]==0)
        b[i/N].zero++;
        else if(a[i]==1)
        b[i/N].one++;
        else
        b[i/N].two++;
    }
    w(m--)
    {
        int type;
        si(type);
        if(type==1)
        {
            int ind,x;
            si(ind);
            si(x);
            x%=3;
            ind--;
                int diff=(x-a1[ind]+3)%3;
                if(diff==1)
                {
                    int st=ind/N;
                    int end=(n-1)/N;
                    int kl=(st+1)*N;
                    int hj=min(n,kl);
                    for(int i=st*N;i<hj;i++)
                    {
                        a[i]=(a[i]+sum[st])%3;
                    }
                    sum[st]=0;
                    for(int i=ind;i<hj;i++)
                    {
                        if(a[i]==0)
                        b[st].zero--;
                        else if(a[i]==1)
                        b[st].one--;
                        else
                        b[st].two--;


                        a[i]=(a[i]+diff)%3;



                        if(a[i]==0)
                        b[st].zero++;
                        else if(a[i]==1)
                        b[st].one++;
                        else
                        b[st].two++;
                    }

                    for(int i=st+1;i<=end;i++)
                    {
                        int yu=b[i].zero;
                        b[i].zero=b[i].two;
                        b[i].two=b[i].one;
                        b[i].one=yu;
                        sum[i]=(sum[i]+diff)%3;
                    }
                }
                else if(diff==2)
                {


                    int st=ind/N;
                    int end=(n-1)/N;
                    int kl=(st+1)*N;
                    int hj=min(n,kl);
                    for(int i=st*N;i<hj;i++)
                    {
                        a[i]=(a[i]+sum[st])%3;
                    }
                    sum[st]=0;
                    for(int i=ind;i<hj;i++)
                    {
                        if(a[i]==0)
                        b[st].zero--;
                        else if(a[i]==1)
                        b[st].one--;
                        else
                        b[st].two--;


                        a[i]=(a[i]+diff)%3;



                        if(a[i]==0)
                        b[st].zero++;
                        else if(a[i]==1)
                        b[st].one++;
                        else
                        b[st].two++;
                    }

                    for(int i=st+1;i<=end;i++)
                    {
                        int yu=b[i].zero;
                        b[i].zero=b[i].one;
                        b[i].one=b[i].two;
                        b[i].two=yu;
                        sum[i]=(sum[i]+diff)%3;
                    }
                }

            a1[ind]=x%3;
        }
        else
        {
            int l,r;
            ll x=0,y=0,z=0;
            si(l);
            si(r);
            l--;
            r--;
            int st=l/N;
            int end=r/N;
            if(st==end)
            {
                for(int i=l;i<=r;i++)
                {
                    ll op=(a[i]+sum[i/N])%3;
                    if(op==0)
                    x++;
                    else if(op==1)
                    y++;
                    else 
                    z++;
                }
            }
            else
            {
                for(int i=l;i<(st+1)*N;i++)
                {
                    ll op=(a[i]+sum[i/N])%3;
                    if(op==0)
                    x++;
                    else if(op==1)
                    y++;
                    else 
                    z++;
                }
                for(int i=end*N;i<=r;i++)
                {
                    ll op=(a[i]+sum[i/N])%3;
                    if(op==0)
                    x++;
                    else if(op==1)
                    y++;
                    else 
                    z++;
                }
                for(int i=st+1;i<=end-1;i++)
                {
                    x+=b[i].zero;
                    y+=b[i].one;
                    z+=b[i].two;
                }
            }
            ll temp=0;
            if(l!=0)
            {
                temp=(a[l-1]+sum[(l-1)/N])%3;
            }
            ll ans=(x*(x-1))/2;
            ans+=((y*(y-1))/2);
            ans+=((z*(z-1))/2);
            if(temp==0)
            ans+=x;
            else if(temp==1)
            ans+=y;
            else
            ans+=z;
            pll(ans);
        }
    }
    return 0;
}