Mersenne Twister 梅森旋转算法笔记

梅森旋转算法简介

梅森旋转算法(Mersenne Twister Algorithm,简称 MT)是为了解决过去伪随机数发生器(Pseudo-Random Number Generator,简称 PRNG)产生的伪随机数质量不高而提出的新算法。该算法由松本眞(Makoto Matsumoto)和西村拓士(Takuji Nishimura)在 1997 年提出,期间还得到了「算法之神」高德纳(Donald Ervin Knuth)的帮助。

Mersenne Twister这个名字来自周期长度取自梅森质数的这样一个事实。这个算法通常使用两个相近的变体,不同之处在于使用了不同的梅森素数。一个更新的和更常用的是MT19937, 32位字长。还有一个变种是64位版的MT19937-64。对于一个k位的长度,Mersenne Twister会在\([0,2^{k}-1]\)的区间之间生成离散型均匀分布的随机数。

其是多种语言的默认随机数算法,如Python、PHP和Matlab等。同时从C++11开始,C++中可以使用 std::mtl9937_64 使用梅森旋转算法。

目前被视为是最好的随机数算法。

算法描述

以下使用 MT19937 代替称呼梅森旋转算法。

MT19937 需要有一个 624 大小的 int32 数组来保存寄存器状态

以下是算法步骤:

  1. 利用 seed 初始化寄存器状态
  2. 对寄存器状态进行旋转
  3. 根据寄存器状态提取伪随机数

下面使用 python 对算法进行剖析。

MT19937 初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class MT19937:
# 常量初始化
n = 624
m = 397
a = 0x9908b0df
b = 0x9d2c5680
c = 0xefc60000
kInitOperand = 0x6c078965
kMaxBits = 0xffffffff
kUpperBits = 0x80000000
kLowerBits = 0x7fffffff
# 寄存器状态
register = [0] * n
# 寄存器状态量
state = 0

# 将数字变为 32 位整数
def _int32(self, x):
return x & 0xffffffff

初始化寄存器状态

1
2
3
4
5
6
7
8
9
10
# 初始化寄存器状态
def seed(self, seed):
self.register[0] = seed
for i in range(1, 624):
self.register[i] = self._int32(self.kInitOperand *
(self.register[i-1] ^ self.register[i-1] >> 30) + i)

def __init__(self, seed = None):
if seed:
self.seed(seed)

旋转

1
2
3
4
5
6
7
8
9
10
# 旋转
def twist(self):
for i in range(self.n):
x = self._int32((self.register[i] & self.kUpperBits) +
(self.register[(i + 1) % self.n] & self.kLowerBits))
self.register[i] = self.register[(i + self.m) % self.n] ^ (x >> 1)

if x & 1 != 0:
self.register[i] ^= 0x9908b0df
self.index = 0

提取伪随机数

1
2
3
4
5
6
7
8
9
10
11
12
13
# 提取
def extract(self):
if self.state >= self.n:
self.twist()

x = self.register[self.state]
x ^= x >> 11
x ^= (x << 7) & self.b
x ^= (x << 15) & self.c
x ^= x >> 18

self.state += 1
return self._int32(x)

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class MT19937:
n = 624
m = 397
a = 0x9908b0df
b = 0x9d2c5680
c = 0xefc60000
kInitOperand = 0x6c078965
kMaxBits = 0xffffffff
kUpperBits = 0x80000000
kLowerBits = 0x7fffffff
register = [0] * n
state = 0

# 将数字变为 32 位整数
def _int32(self, x):
return x & 0xffffffff

# 初始化寄存器状态
def seed(self, seed):
self.register[0] = seed
for i in range(1, self.n):
self.register[i] = self._int32(self.kInitOperand *
(self.register[i-1] ^ self.register[i-1] >> 30) + i)

def __init__(self, seed = None):
if seed:
self.seed(seed)

# 旋转
def twist(self):
for i in range(self.n):
x = self._int32((self.register[i] & self.kUpperBits) +
(self.register[(i + 1) % self.n] & self.kLowerBits))
self.register[i] = self.register[(i + self.m) % self.n] ^ (x >> 1)

if x & 1 != 0:
self.register[i] ^= self.a
self.state = 0

# 提取
def extract(self):
if self.state >= self.n:
self.twist()

x = self.register[self.state]
x ^= x >> 11
x ^= (x << 7) & self.b
x ^= (x << 15) & self.c
x ^= x >> 18

self.state += 1
return self._int32(x)

def __call__(self):
return self.extract()

状态破解

梅森旋转算法的设计目的是优秀的伪随机数发生算法,而不是产生密码学上安全的随机数。从梅森旋转算法的结构上说,其提取算法 extract 完全基于二进制的按位异或;而二进制按位异或是可逆的,故而 extract 是可逆的。这就意味着,攻击者可以从梅森旋转算法的输出,逆推出产生该输出的内部寄存器状态 register[state]。若攻击者能够获得连续的至少 n个寄存器状态,那么攻击者就能预测出接下来的随机数序列。

以下我们用 python 对算法逐步破解

右移位后异或逆向

首先观察原函数。

1
2
3
4
def right_shift_xor(value, shift):
result = value
result ^= (result >> shift)
return result

简单起见,我们观察一个 8 位二进制数,右移 3 位后异或的过程。

1
2
3
value:    1101 0010
shifted: 0001 1010 # 010 (>> 3)
result: 1100 1000

首先,观察到 result 的最高 shift 位与 value 的最高 shift 位是一样的。因此,在 result 的基础上,我们可以将其与一个二进制遮罩取与,得到 value 的最高 shift 位。这个遮罩应该是:1111 1111 << (8 - 3) = 1110 0000。于是我们得到 1100 0000

其次,注意到对于异或运算有如下事实:a ^ b ^ b = a。依靠二进制遮罩,我们已经获得了 value 的最高 shift 位。因此,我们也就能得到 shifted 的最高 2 * shift 位。它应该是 1100 0000 >> 3 = 0001 1000。将其与 result 取异或,则能得到 value 的最高 2 * shift 位。于是我们得到 1101 0000

故我们可以逆向出代码:

1
2
3
4
5
def unshiftRight(self, x, shift):
res = x
for i in range(32):
res = x ^ res >> shift
return res

左移位后异或逆向

同理可得代码

1
2
3
4
5
def unshiftLeft(self, x, shift, mask):
res = x
for i in range(32):
res = x ^ (res << shift & mask)
return res

提取伪随机数逆向

1
2
3
4
5
6
7
8
b = 0x9d2c5680
c = 0xefc60000
def untemper(self, v):
v = self.unshiftRight(v, 18)
v = self.unshiftLeft(v, 15, self.c)
v = self.unshiftLeft(v, 7, self.b)
v = self.unshiftRight(v, 11)
return v

通过输出参数逆向

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# forward 表示是否需要回到目前状态
def go(self, outputs, forward=True):
# 还原的寄存器状态
result_state = None

# 至少需要 624 个寄存器状态
assert len(outputs) >= 624

ivals = []
for i in range(624):
ivals.append(self.untemper(outputs[i]))

if len(outputs) >= 625:
# 此时可以使用后面的数据进行验证
challenge = outputs[624]
for i in range(1, 626):
state = (3, tuple(ivals+[i]), None)
r = random.Random()
r.setstate(state)

if challenge == r.getrandbits(32):
result_state = state
break
else:
# 如果刚好是 624 个寄存器状态
result_state = (3, tuple(ivals+[624]), None)

# 利用 python 自带的 mt19937 random 库
rand = random.Random()
rand.setstate(result_state)

# 如果需要到输出后的状态,则进行比较判断
if forward:
for i in range(624, len(outputs)):
assert rand.getrandbits(32) == outputs[i]

return rand

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class MT19937Recover:
def unshiftRight(self, x, shift):
res = x
for i in range(32):
res = x ^ res >> shift
return res

def unshiftLeft(self, x, shift, mask):
res = x
for i in range(32):
res = x ^ (res << shift & mask)
return res

def untemper(self, v):
v = self.unshiftRight(v, 18)
v = self.unshiftLeft(v, 15, 0xefc60000)
v = self.unshiftLeft(v, 7, 0x9d2c5680)
v = self.unshiftRight(v, 11)
return v

def go(self, outputs, forward=True):
result_state = None

assert len(outputs) >= 624

ivals = []
for i in range(624):
ivals.append(self.untemper(outputs[i]))

if len(outputs) >= 625:
challenge = outputs[624]
for i in range(1, 626):
state = (3, tuple(ivals+[i]), None)
r = random.Random()
r.setstate(state)

if challenge == r.getrandbits(32):
result_state = state
break
else:
result_state = (3, tuple(ivals+[624]), None)

rand = random.Random()
rand.setstate(result_state)

if forward:
for i in range(624, len(outputs)):
assert rand.getrandbits(32) == outputs[i]

return rand

mtc = MT19937Recover()

现有框架

我们可以使用 Python 的 randcrack 模块破解梅森随机数,这是用来破解 Python 随机数的模块。

矩阵状态破解

我们对 extract_number 函数进行分析,设 state[i] 的二进制表示形式为 \[ x_0x_1\cdots x_{30}x_{31} \] 那么输出的随机数的二进制表示形式为 \[ z_0z_1\cdots z_{30}z_{31} \] 那么可以发现它们存在如下线性关系 \[ \begin{array}{l} z_{0}=x_{0} \oplus x_{4} \oplus x_{7} \oplus x_{15} \\ z_{1}=x_{1} \oplus x_{5} \oplus x_{16} \\ z_{2}=x_{2} \oplus x_{6} \oplus x_{13} \oplus x_{17} \oplus x_{24} \\ z_{3}=x_{3} \oplus x_{10} \\ z_{4}=x_{0} \oplus x_{4} \oplus x_{8} \oplus x_{11} \oplus x_{15} \oplus x_{19} \oplus x_{26} \\ z_{5}=x_{1} \oplus x_{5} \oplus x_{9} \oplus x_{12} \oplus x_{20} \\ z_{6}=x_{6} \oplus x_{10} \oplus x_{17} \oplus x_{21} \oplus x_{28} \\ z_{7}=x_{3} \oplus x_{7} \oplus x_{11} \oplus x_{14} \oplus x_{18} \oplus x_{22} \oplus x_{29} \\ z_{8}=x_{8} \oplus x_{12} \oplus x_{23} \\ z_{9}=x_{9} \oplus x_{13} \oplus x_{20} \oplus x_{24} \oplus x_{31} \\ z_{10}=x_{6} \oplus x_{10} \oplus x_{17} \\ z_{11}=x_{0} \oplus x_{11} \\ z_{12}=x_{1} \oplus x_{8} \oplus x_{12} \oplus x_{19} \\ z_{13}=x_{2} \oplus x_{9} \oplus x_{13} \oplus x_{17} \oplus x_{20} \oplus x_{28} \\ z_{14}=x_{3} \oplus x_{14} \oplus x_{18} \oplus x_{29} \\ z_{15}=x_{4} \oplus x_{15} \\ z_{16}=x_{5} \oplus x_{16} \\ z_{17}=x_{6} \oplus x_{13} \oplus x_{17} \oplus x_{24} \\ z_{18}=x_{0} \oplus x_{4} \oplus x_{15} \oplus x_{18} \\ z_{19}=x_{1} \oplus x_{5} \oplus x_{8} \oplus x_{15} \oplus x_{16} \oplus x_{19} \oplus x_{26} \\ z_{20}=x_{2} \oplus x_{6} \oplus x_{9} \oplus x_{13} \oplus x_{17} \oplus x_{20} \oplus x_{24} \\ z_{21}=x_{3} \oplus x_{17} \oplus x_{21} \oplus x_{28} \\ z_{22}=x_{0} \oplus x_{4} \oplus x_{8} \oplus x_{15} \oplus x_{18} \oplus x_{19} \oplus x_{22} \oplus x_{26} \oplus x_{29} \\ z_{23}=x_{1} \oplus x_{5} \oplus x_{9} \oplus x_{20} \oplus x_{23} \\ z_{24}=x_{6} \oplus x_{10} \oplus x_{13} \oplus x_{17} \oplus x_{20} \oplus x_{21} \oplus x_{24} \oplus x_{28} \oplus x_{31} \\ z_{25}=x_{3} \oplus x_{7} \oplus x_{11} \oplus x_{18} \oplus x_{22} \oplus x_{25} \oplus x_{29} \\ z_{26}=x_{8} \oplus x_{12} \oplus x_{15} \oplus x_{23} \oplus x_{26} \\ z_{27}=x_{9} \oplus x_{13} \oplus x_{16} \oplus x_{20} \oplus x_{24} \oplus x_{27} \oplus x_{31} \\ z_{28}=x_{6} \oplus x_{10} \oplus x_{28} \\ z_{29}=x_{0} \oplus x_{11} \oplus x_{18} \oplus x_{29} \\ z_{30}=x_{1} \oplus x_{8} \oplus x_{12} \oplus x_{30} \\ z_{31}=x_{2} \oplus x_{9} \oplus x_{13} \oplus x_{17} \oplus x_{28} \oplus x_{31} \end{array} \] 也就是说,存在一个 \(\mathrm{GF}(2)\) 的矩阵可以将 \(X\) 变化为 \(Z\),即 \[ XT=Z \] 我们可以采用黑盒测试的方式计算出 \(T\),例如当 \(X=(1,0,\cdots,0)\) 时,得到矩阵 \(T\) 的第一行。

使用 sagemath 编写代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def extract_number(x):
x ^^= x >> 11
x ^^= (x << 7) & 2636928640
x ^^= (x << 15) & 4022730752
x ^^= x >> 18
return x & 0xffffffff

T = []
for i in range(32):
x = 1 << (31 - i)
x = extract_number(x)
T.append(x.digits(2, padto=32)[::-1])
T = matrix(GF(2), T)

def untemper(leak):
Z = matrix(GF(2), ZZ(leak).digits(2, padto=32)[::-1])
X = T.solve_left(Z)
return reduce(lambda x, y: (ZZ(x) << 1) + ZZ(y), list(X[0]))

同样可以获得 untemper

这种方式的好处是可以针对不同的梅森随机数参数。

旋转破解

我们考虑旋转算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
m = 397
a = 0x9908b0df
b = 0x9d2c5680
c = 0xefc60000
kInitOperand = 0x6c078965
kMaxBits = 0xffffffff
kUpperBits = 0x80000000
kLowerBits = 0x7fffffff
def twist(self):
for i in range(self.n):
x = self._int32((self.register[i] & 0x8) +
(self.register[(i + 1) % self.n] & self.kLowerBits))
self.register[i] = self.register[(i + self.m) % self.n] ^ (x >> 1)

if x & 1 != 0:
# 0b10011001000010001011000011011111 == 0x9908b0df
self.register[i] ^= 0x9908b0df
self.index = 0

发现新的 newState[i] 取决于旧的 state[i],state[i+1],state[i+397]

其中,由于 newState[i] = state[i+397] ^ (x >> 1),故 newState[i] 的最高位只与 state[i+397]0x9908b0df

由于旋转的时候是迭代进行的,故 newState[623]newState[396] 有关系,显然我们可以确定它,故我们通过此可以还原出 newState[396] ^ (x >> 1)state[623] 的最高位(即 x 的最高位)和 x 的最低位(会发生异或一定是 x & 1 != 0)。

而同理,由于 newState[396] 是确定的,所以可以解得 x >> 1 的结果,故我们最终还原出了 x,即 state[623] 的最高位和 newState[0] 的剩余位。

通过这种办法迭代可得旋转前的 state

可以写出代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
n = 624
m = 397
a = 0x9908b0df
b = 0x9d2c5680
c = 0xefc60000
kInitOperand = 0x6c078965
kMaxBits = 0xffffffff
kUpperBits = 0x80000000
kLowerBits = 0x7fffffff

def untwist(newState, flag: bool = True):
oldState = [0] * 624
for i in range(n - 1, -2, -1):
x = newState[i] ^ newState[(i + m) % n]
if x & kUpperBits == kUpperBits:
x ^= a
x <<= 1
x |= 1
else:
x <<= 1
if i > -1:
oldState[i] |= x & kUpperBits
if i + 1 < n:
oldState[i + 1] |= x & kLowerBits
if i == 227 and flag:
newState = list(newState[:227]) + oldState[227:]
return oldState