**Bài toán:** Đếm số lượng 1 ≤ (i,j) ≤ n thỏa mãn a[i] ≤ (a[i] ⊕ a[j]) ≤ a[j]. \
Để đơn giản trong xử lí ta mặc định bài toán là: \
Cho dãy a sắp xếp không giảm. Đếm số bộ 1 ≤ (i ≤ j) ≤ n thỏa mãn điều kiện trên.
**Nhận xét:** \
**1.** Mỗi cặp (i,j) sao cho a[i] = 0 hoặc a[j] = 0 hoặc a[i] = a[j] = 0 đều thỏa mãn. \
-> Ta đếm riêng từng bộ như thế này. \
**2.** Với a[i] ⊕ a[j] = 0 thỏa mãn điều kiện khi và chỉ khi a[i] = a[j] = 0. \
Từ **(1)(2)**, bài toán của ta chỉ cần đếm các cặp nguyên dương (i < j, a[i] < a[j]) thỏa mãn điều kiện.
**3.** Nhận thấy a[i] ⊕ a[j] ≤ a[j] khi và chỉ khi bit bật cao nhất của a[i] < bit bật cao nhất của a[j].
Xét a[j] = (1 << p) + r (p là bit bật cao nhất của a[j]). \
Bài toán biến đổi thành: Với mỗi a[j] đếm số lượng a[i] ⊕ r ≤ r (i < j, a[i] < a[j]).
Cần sử dụng CTDL Trie nhị phân và một chút "trick lỏ" của DP-digit để giải. \
Nguồn tham khảo: https://wiki.vnoi.info/algo/string/trie | https://wiki.vnoi.info/algo/dp/digit-dp
Code:
```
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define FOR(i,a,b) for (int i=(a), _b = (b); i <= (_b); ++i)
#define FORD(i,a,b) for (int i=(a), _b = (b); i >= (_b); --i)
#define el cout << '\n'
//--Compare------------------------------------------------------------------------------------
template<class X, class Y>
inline bool maximize(X &x, const Y &y){return (x < y) ? x = y, 1 : 0;}
template<class X, class Y>
inline bool minimize(X &x, const Y &y){return (x > y) ? x = y, 1 : 0;}
//--Process------------------------------------------------------------------------------------
constexpr int MAXN = 5e5 + 100, MAXBIT = 29;
struct node
{
int child[2];
int cnt;
node() : cnt(0) { child[0] = child[1] = -1; }
} nodes[MAXN << 4];
int root, nodeCount = 0;
int newNode(){ return nodeCount++; }
void add(int x)
{
int p = root;
FORD(i,MAXBIT,0)
{
bool bit = (x >> i) & 1;
if (nodes[p].child[bit] == -1) nodes[p].child[bit] = newNode();
p = nodes[p].child[bit];
nodes[p].cnt++;
}
}
int get(int x, int p = root, int pos = MAXBIT, bool tight = true)
{
if (p == -1) return 0;
if (pos < 0 || !tight) return nodes[p].cnt;
bool bit = (x >> pos) & 1;
int res = 0;
if (bit == 0) res += get(x, nodes[p].child[0], pos - 1, true);
if (bit == 1)
{
res += get(x, nodes[p].child[1], pos - 1, false);
res += get(x, nodes[p].child[0], pos - 1, true);
}
return res;
}
int n;
int cnt0 = 0;
vector <int> v;
signed main(void)
{
cin.tie(nullptr)->sync_with_stdio(false);
cin.exceptions(cin.failbit);
cin >> n;
FOR(i,1,n)
{
int x; cin >> x;
cnt0 += (x == 0);
if (x) v.emplace_back(x);
}
ll res = 1LL * cnt0 * n - 1LL * (cnt0 - 1LL) * cnt0 / 2LL;
root = newNode();
sort(begin(v), end(v));
for (int x : v)
{
int r = x - (1 << (31 - __builtin_clz(x)));
res += get(r);
add(x);
}
cout << res;
cerr << (1.0 * clock() / CLOCKS_PER_SEC);
return 0;
}
```
# **nên vô đây đọc: https://hackmd.io/@Pu71cy6tShWTJLBu6sd4gA/ByH5p1hUgx
# **Phân tích và Giải pháp**
### **1. Hướng tiếp cận**
> **Hướng tiếp cận:** (Giải thích thì dài, nhma code siêu gọn🐧)
>
> **Tóm gọn mục tiêu:** Tìm các cặp `(i, j)` sao cho: $$a_i \le (a_i \oplus a_j) \le a_j$$
>
> - Bài này đếm cặp => ta ưu tiên **sort** lại để đếm.
>
> **Lưu ý 2 tính chất quan trọng:**
> - $2^k > \sum_{p=0}^{k-1} 2^p$ (Bit lớn nhất quyết định độ lớn của số).
> - `1 ⊕ 1 = 0` (Phép XOR có thể "tắt" bit).
Sau khi sắp xếp, bài toán chuyển thành: Với mỗi `y = a[j]`, ta cần tìm số lượng `x = a[i]` (với `i ≤ j`) thỏa mãn $x \le (x \oplus y) \le y$.
---
### **2. Phân tích & Chứng minh**
Ta cần thỏa mãn đồng thời 2 điều kiện:
1. `(x ⊕ y) ≤ y`: Để kết quả không lớn hơn `y`, phép XOR phải "tắt" ít nhất một bit `1` của `y`. Điều này chỉ xảy ra khi `x` có chung một bit `1` nào đó với `y`.
2. `x ≤ (x ⊕ y)`: Đây là điều kiện khó hơn.
> **Chứng minh bằng ví dụ:**
>
> - **Trường hợp sai:** Nếu `MSB(x)` (bit 1 cao nhất của `x`) trùng với `MSB(y)`.
> - Ví dụ: `y = 169 (10101001)`, `x = 128 (10000000)`.
> - `MSB(x) = MSB(y) = 7`.
> - `x ⊕ y = 41 (00101001)`.
> - Ta thấy `41 < 128`, tức là `(x ⊕ y) < x`. **Không thỏa mãn.**
> - *Lý do:* Khi tắt đi bit quan trọng nhất, giá trị của số sẽ giảm mạnh, nhỏ hơn cả `x` ban đầu.
>
> - **Trường hợp đúng:** Để `x ≤ (x ⊕ y)`, `MSB(y)` phải được giữ nguyên trong phép XOR. Điều này có nghĩa là bit tại vị trí `MSB(y)` của `x` phải là `0` (luôn đúng khi `x < y`).
>
> **Kết hợp lại:**
>
> Để thỏa mãn cả 2 điều kiện, `x` phải "tắt" một bit nào đó của `y` nhưng không được đụng đến `MSB(y)`.
>
> => **Điều kiện cuối cùng:** `MSB(x)` phải trùng với một trong các bit `1` của `y`, **ngoại trừ** bit `MSB(y)`.
>
> Nói cách khác:
>
> > Với mỗi `y = a[j]`, ta chỉ cần tìm số lượng `x = a[i]` (`i < j`) sao cho `MSB(x)` trùng với một trong các bit từ **"vị trí 1 lớn nhì"** trở xuống của `y`.
---
### **3. Thuật toán & Code**
- Sắp xếp mảng `a`.
- Duyệt qua từng phần tử `y = a[i]`.
- Với mỗi `y`, ta duyệt qua các vị trí bit `k` của nó (từ bit 1 lớn nhì trở xuống).
- Nếu bit `k` của `y` đang bật, ta đếm xem đã có bao nhiêu số `x` đứng trước có `MSB(x) = k`.
- Dùng mảng `save[]` để lưu tần suất xuất hiện của các `MSB`. `save[k+1]` lưu số lượng số có `MSB` tại vị trí `k`.
> **CÁCH TÌM VỊ TRÍ BIT LỚN NHẤT CỦA 1 SỐ X:** `__lg(x)`
**Code tham khảo:**
```cpp
#include <bits/stdc++.h>
// Hàm tìm MSB, nếu x=0 thì trả về -1 để tiện xử lý
#define log2(x) ((x <= 0) ? -1 : __lg(x))
#define FOR(i, l, r, x) for(int i = l; i <= r; i += x)
#define FOD(i, l, r, x) for(int i = l; i >= r; i -= x)
#define pii pair<int, int>
#define fi first
#define se second
#define int long long
using namespace std;
const int N = 5e5 + 5;
const int mod = 1e9 + 7;
int n, a[N], save[N];
// Hàm tìm vị trí bit 1 lớn nhì
int bit2nd(int n) {
int inx = log2(n); // Tìm bit lớn nhất
if (inx == -1) return -1;
n ^= (1 << inx); // Tắt bit lớn nhất đi
return log2(n); // Tìm bit lớn nhất của số còn lại
}
// Kiểm tra bit thứ i của mask có bật không
bool isOn(int mask, int i) {
return ((mask >> i) & 1);
}
void solve() {
cin >> n;
FOR(i, 1, n, 1) {
cin >> a[i];
}
// Sắp xếp để đảm bảo khi xét a[i], các số đứng trước luôn nhỏ hơn hoặc bằng
sort(a + 1, a + n + 1);
int ans = 0;
// Duyệt qua mảng, coi a[i] là y (phần tử lớn hơn trong cặp)
FOR(i, 1, n, 1) {
int y = a[i];
// Tối ưu: chỉ cần xét các bit của y lên tới bit 1 lớn nhì của nó
int pos = bit2nd(y) + 1;
// Xử lý cặp (y, y). Nếu y=0, đây là cặp (0,0).
// Nếu y>0, đây là cặp (y,y) cũng luôn hợp lệ.
// Code gốc gộp lại: nếu y=0 thì ++ans, các cặp (0,0) sau sẽ được đếm ở save.
if (y == 0) {
++ans;
}
// Duyệt qua các khả năng cho MSB của x (phần tử nhỏ hơn).
// j là log2(x) + 1
FOR(j, 0, pos, 1) {
// Trường hợp x = 0. Cặp (0, y) luôn hợp lệ.
if (j == 0) {
ans += save[j]; // Cộng tất cả các số 0 đã gặp
continue;
}
// Đây là điều kiện cốt lõi từ chứng minh:
// Kiểm tra xem bit tại vị trí MSB của x (tức là j-1)
// có được bật trong y (tức là a[i]) hay không.
if (isOn(y, j - 1)) {
// Nếu có, tất cả các số x đã gặp có MSB tại vị trí đó đều tạo cặp hợp lệ.
ans += save[j];
}
}
// Cập nhật: "ghi nhớ" sự xuất hiện của y (a[i]) cho các vòng lặp sau.
// save[k+1] lưu số lượng số có MSB tại vị trí k.
save[log2(y) + 1] ++;
}
cout << ans << '\n';
}
signed main() {
#define name "task"
if (ifstream(name".inp")) {
freopen(name".inp", "r", stdin);
freopen(name".out", "w", stdout);
}
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
solve();
}
Gọi 2 số đó lần lượt là a,b (a >= b).
Trường hợp 1: ít nhất 1 trong 2 số bằng 0:
- Cặp số này luôn thõa mãn điều kiện, ta chỉ cần đếm riêng trường hợp này.
Trường hợp 2: a,b > 0:
Ta có điều kiện b < b ⊕ a < a. (ta có thể bỏ dấu bằng vì trường hợp đấy chỉ xảy ra trong trường hợp 1)
- Gọi x là bit bật cao nhất của a.
- Nhận xét: bit thứ x của b phải tắt (vì nếu bật thì khi ⊕ lại sẽ < b), từ đó đã thõa mãn điều kiện ⊕ của chúng > b.
- Gọi y là bit bật cao nhất của b.
- Nhận xét: bit thứ y của a phải bật (vì nếu tắt thì khi ⊕ lại sẽ > a), từ đó cũng thõa mãn điều kiện ⊕ của chúng < a.
Vậy ta chỉ cần làm như sau:
- không mất tính tổng quát ta có thể sort mảng a tăng dần
- TH1: dễ dàng xử lí
- TH2: với mỗi vị trí i ta cần đếm xem có bao nhiêu vị trí j < i sao cho bit bật cao nhất của a[j] < bit bật cao nhất của a[i] và a[i] cũng bật cái bit ấy.
Code:
```
#include <bits/stdc++.h>
#define ll long long
using namespace std;
int n;
int a[500005];
int cnt[31]; // cnt[i]: số lượng số có bit bật cao nhất là i
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n;
for (int i=1;i<=n;i++){
cin >> a[i];
}
sort(a+1, a+1+n);
ll ans = 0;
int dem = 0; // lưu số lượng số = 0
for (int i=1;i<=n;i++){
ans += dem;
if (a[i] == 0){
dem++;
ans++;
continue;
}
int k = __lg(a[i]); // bit bật cao nhất
for (int j=0;j<k;j++){
if ((1 << j) & a[i]){
ans += cnt[j];
}
}
cnt[k]++;
}
cout << ans;
return 0;
}
```