Solutions of XOR - MarisaOJ: Marisa Online Judge

Solutions of XOR

Select solution language

Write solution here.


dlnd_quochung09    Created at    1 likes

**Bài toán:** Đếm số lượng 1 ≤ (i,j) ≤ n thỏa mãn a[i] ≤ (a[i] ⊕ a[j]) ≤ a[j]. \ Để đơn giản trong xử lí ta mặc định bài toán là: \ Cho dãy a sắp xếp không giảm. Đếm số bộ 1 ≤ (i ≤ j) ≤ n thỏa mãn điều kiện trên. **Nhận xét:** \ **1.** Mỗi cặp (i,j) sao cho a[i] = 0 hoặc a[j] = 0 hoặc a[i] = a[j] = 0 đều thỏa mãn. \ -> Ta đếm riêng từng bộ như thế này. \ **2.** Với a[i] ⊕ a[j] = 0 thỏa mãn điều kiện khi và chỉ khi a[i] = a[j] = 0. \ Từ **(1)(2)**, bài toán của ta chỉ cần đếm các cặp nguyên dương (i < j, a[i] < a[j]) thỏa mãn điều kiện. **3.** Nhận thấy a[i] ⊕ a[j] ≤ a[j] khi và chỉ khi bit bật cao nhất của a[i] < bit bật cao nhất của a[j]. Xét a[j] = (1 << p) + r (p là bit bật cao nhất của a[j]). \ Bài toán biến đổi thành: Với mỗi a[j] đếm số lượng a[i] ⊕ r ≤ r (i < j, a[i] < a[j]). Cần sử dụng CTDL Trie nhị phân và một chút "trick lỏ" của DP-digit để giải. \ Nguồn tham khảo: https://wiki.vnoi.info/algo/string/trie | https://wiki.vnoi.info/algo/dp/digit-dp Code: ``` #include <bits/stdc++.h> using namespace std; #define ll long long #define FOR(i,a,b) for (int i=(a), _b = (b); i <= (_b); ++i) #define FORD(i,a,b) for (int i=(a), _b = (b); i >= (_b); --i) #define el cout << '\n' //--Compare------------------------------------------------------------------------------------ template<class X, class Y> inline bool maximize(X &x, const Y &y){return (x < y) ? x = y, 1 : 0;} template<class X, class Y> inline bool minimize(X &x, const Y &y){return (x > y) ? x = y, 1 : 0;} //--Process------------------------------------------------------------------------------------ constexpr int MAXN = 5e5 + 100, MAXBIT = 29; struct node { int child[2]; int cnt; node() : cnt(0) { child[0] = child[1] = -1; } } nodes[MAXN << 4]; int root, nodeCount = 0; int newNode(){ return nodeCount++; } void add(int x) { int p = root; FORD(i,MAXBIT,0) { bool bit = (x >> i) & 1; if (nodes[p].child[bit] == -1) nodes[p].child[bit] = newNode(); p = nodes[p].child[bit]; nodes[p].cnt++; } } int get(int x, int p = root, int pos = MAXBIT, bool tight = true) { if (p == -1) return 0; if (pos < 0 || !tight) return nodes[p].cnt; bool bit = (x >> pos) & 1; int res = 0; if (bit == 0) res += get(x, nodes[p].child[0], pos - 1, true); if (bit == 1) { res += get(x, nodes[p].child[1], pos - 1, false); res += get(x, nodes[p].child[0], pos - 1, true); } return res; } int n; int cnt0 = 0; vector <int> v; signed main(void) { cin.tie(nullptr)->sync_with_stdio(false); cin.exceptions(cin.failbit); cin >> n; FOR(i,1,n) { int x; cin >> x; cnt0 += (x == 0); if (x) v.emplace_back(x); } ll res = 1LL * cnt0 * n - 1LL * (cnt0 - 1LL) * cnt0 / 2LL; root = newNode(); sort(begin(v), end(v)); for (int x : v) { int r = x - (1 << (31 - __builtin_clz(x))); res += get(r); add(x); } cout << res; cerr << (1.0 * clock() / CLOCKS_PER_SEC); return 0; } ```