梅森旋转算法简介
梅森旋转算法(Mersenne Twister Algorithm,简称
MT)是为了解决过去伪随机数发生器(Pseudo-Random Number Generator,简称
PRNG)产生的伪随机数质量不高而提出的新算法。该算法由松本眞(Makoto
Matsumoto)和西村拓士(Takuji Nishimura)在 1997
年提出,期间还得到了「算法之神」高德纳(Donald Ervin Knuth)的帮助。
Mersenne
Twister这个名字来自周期长度取自梅森质数的这样一个事实。这个算法通常使用两个相近的变体,不同之处在于使用了不同的梅森素数。一个更新的和更常用的是MT19937,
32位字长。还有一个变种是64位版的MT19937-64。对于一个k 位的长度,Mersenne
Twister会在\([0,2^{k}-1]\) 的区间之间生成离散型均匀分布的随机数。
其是多种语言的默认随机数算法,如Python、PHP和Matlab等。同时从C++11开始,C++中可以使用
std::mtl9937_64
使用梅森旋转算法。
目前被视为是最好的随机数算法。
算法描述
以下使用 MT19937 代替称呼梅森旋转算法。
MT19937 需要有一个 624 大小的 int32
数组来保存寄存器状态 。
以下是算法步骤:
利用 seed 初始化寄存器状态
对寄存器状态进行旋转
根据寄存器状态提取伪随机数
下面使用 python 对算法进行剖析。
MT19937 初始化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 class MT19937 : n = 624 m = 397 a = 0x9908b0df b = 0x9d2c5680 c = 0xefc60000 kInitOperand = 0x6c078965 kMaxBits = 0xffffffff kUpperBits = 0x80000000 kLowerBits = 0x7fffffff register = [0 ] * n state = 0 def _int32 (self, x ): return x & 0xffffffff
初始化寄存器状态
1 2 3 4 5 6 7 8 9 10 def seed (self, seed ): self.register[0 ] = seed for i in range (1 , 624 ): self.register[i] = self._int32(self.kInitOperand * (self.register[i-1 ] ^ self.register[i-1 ] >> 30 ) + i) def __init__ (self, seed = None ): if seed: self.seed(seed)
旋转
1 2 3 4 5 6 7 8 9 10 def twist (self ): for i in range (self.n): x = self._int32((self.register[i] & self.kUpperBits) + (self.register[(i + 1 ) % self.n] & self.kLowerBits)) self.register[i] = self.register[(i + self.m) % self.n] ^ (x >> 1 ) if x & 1 != 0 : self.register[i] ^= 0x9908b0df self.index = 0
提取伪随机数
1 2 3 4 5 6 7 8 9 10 11 12 13 def extract (self ): if self.state >= self.n: self.twist() x = self.register[self.state] x ^= x >> 11 x ^= (x << 7 ) & self.b x ^= (x << 15 ) & self.c x ^= x >> 18 self.state += 1 return self._int32(x)
完整代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 class MT19937 : n = 624 m = 397 a = 0x9908b0df b = 0x9d2c5680 c = 0xefc60000 kInitOperand = 0x6c078965 kMaxBits = 0xffffffff kUpperBits = 0x80000000 kLowerBits = 0x7fffffff register = [0 ] * n state = 0 def _int32 (self, x ): return x & 0xffffffff def seed (self, seed ): self.register[0 ] = seed for i in range (1 , self.n): self.register[i] = self._int32(self.kInitOperand * (self.register[i-1 ] ^ self.register[i-1 ] >> 30 ) + i) def __init__ (self, seed = None ): if seed: self.seed(seed) def twist (self ): for i in range (self.n): x = self._int32((self.register[i] & self.kUpperBits) + (self.register[(i + 1 ) % self.n] & self.kLowerBits)) self.register[i] = self.register[(i + self.m) % self.n] ^ (x >> 1 ) if x & 1 != 0 : self.register[i] ^= self.a self.state = 0 def extract (self ): if self.state >= self.n: self.twist() x = self.register[self.state] x ^= x >> 11 x ^= (x << 7 ) & self.b x ^= (x << 15 ) & self.c x ^= x >> 18 self.state += 1 return self._int32(x) def __call__ (self ): return self.extract()
状态破解
梅森旋转算法的设计目的是优秀的伪随机数发生算法,而不是产生密码学上安全的随机数。从梅森旋转算法的结构上说,其提取算法
extract
完全基于二进制的按位异或;而二进制按位异或是可逆的,故而
extract
是可逆的。这就意味着,攻击者可以从梅森旋转算法的输出,逆推出产生该输出的内部寄存器状态
register[state]
。若攻击者能够获得连续的至少
n
个寄存器状态,那么攻击者就能预测出接下来的随机数序列。
以下我们用 python 对算法逐步破解
右移位后异或逆向
首先观察原函数。
1 2 3 4 def right_shift_xor (value, shift ): result = value result ^= (result >> shift) return result
简单起见,我们观察一个 8 位二进制数,右移 3 位后异或的过程。
1 2 3 value: 1101 0010 shifted: 0001 1010 # 010 (>> 3 ) result: 1100 1000
首先,观察到 result
的最高 shift
位与
value
的最高 shift
位是一样的。因此,在
result
的基础上,我们可以将其与一个二进制遮罩取与,得到
value
的最高 shift
位。这个遮罩应该是:1111 1111 << (8 - 3) = 1110 0000
。于是我们得到
1100 0000
。
其次,注意到对于异或运算有如下事实:a ^ b ^ b = a
。依靠二进制遮罩,我们已经获得了
value
的最高 shift
位。因此,我们也就能得到
shifted
的最高 2 * shift
位。它应该是
1100 0000 >> 3 = 0001 1000
。将其与
result
取异或,则能得到 value
的最高
2 * shift
位。于是我们得到 1101 0000
。
故我们可以逆向出代码:
1 2 3 4 5 def unshiftRight (self, x, shift ): res = x for i in range (32 ): res = x ^ res >> shift return res
左移位后异或逆向
同理可得代码
1 2 3 4 5 def unshiftLeft (self, x, shift, mask ): res = x for i in range (32 ): res = x ^ (res << shift & mask) return res
提取伪随机数逆向
1 2 3 4 5 6 7 8 b = 0x9d2c5680 c = 0xefc60000 def untemper (self, v ): v = self.unshiftRight(v, 18 ) v = self.unshiftLeft(v, 15 , self.c) v = self.unshiftLeft(v, 7 , self.b) v = self.unshiftRight(v, 11 ) return v
通过输出参数逆向
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 def go (self, outputs, forward=True ): result_state = None assert len (outputs) >= 624 ivals = [] for i in range (624 ): ivals.append(self.untemper(outputs[i])) if len (outputs) >= 625 : challenge = outputs[624 ] for i in range (1 , 626 ): state = (3 , tuple (ivals+[i]), None ) r = random.Random() r.setstate(state) if challenge == r.getrandbits(32 ): result_state = state break else : result_state = (3 , tuple (ivals+[624 ]), None ) rand = random.Random() rand.setstate(result_state) if forward: for i in range (624 , len (outputs)): assert rand.getrandbits(32 ) == outputs[i] return rand
完整代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 class MT19937Recover : def unshiftRight (self, x, shift ): res = x for i in range (32 ): res = x ^ res >> shift return res def unshiftLeft (self, x, shift, mask ): res = x for i in range (32 ): res = x ^ (res << shift & mask) return res def untemper (self, v ): v = self.unshiftRight(v, 18 ) v = self.unshiftLeft(v, 15 , 0xefc60000 ) v = self.unshiftLeft(v, 7 , 0x9d2c5680 ) v = self.unshiftRight(v, 11 ) return v def go (self, outputs, forward=True ): result_state = None assert len (outputs) >= 624 ivals = [] for i in range (624 ): ivals.append(self.untemper(outputs[i])) if len (outputs) >= 625 : challenge = outputs[624 ] for i in range (1 , 626 ): state = (3 , tuple (ivals+[i]), None ) r = random.Random() r.setstate(state) if challenge == r.getrandbits(32 ): result_state = state break else : result_state = (3 , tuple (ivals+[624 ]), None ) rand = random.Random() rand.setstate(result_state) if forward: for i in range (624 , len (outputs)): assert rand.getrandbits(32 ) == outputs[i] return rand mtc = MT19937Recover()
现有框架
我们可以使用 Python 的 randcrack 模块破解梅森随机数,这是用来破解
Python 随机数的模块。
矩阵状态破解
我们对 extract_number
函数进行分析,设
state[i]
的二进制表示形式为 \[
x_0x_1\cdots x_{30}x_{31}
\] 那么输出的随机数的二进制表示形式为 \[
z_0z_1\cdots z_{30}z_{31}
\] 那么可以发现它们存在如下线性关系 \[
\begin{array}{l}
z_{0}=x_{0} \oplus x_{4} \oplus x_{7} \oplus x_{15} \\
z_{1}=x_{1} \oplus x_{5} \oplus x_{16} \\
z_{2}=x_{2} \oplus x_{6} \oplus x_{13} \oplus x_{17} \oplus x_{24} \\
z_{3}=x_{3} \oplus x_{10} \\
z_{4}=x_{0} \oplus x_{4} \oplus x_{8} \oplus x_{11} \oplus x_{15} \oplus
x_{19} \oplus x_{26} \\
z_{5}=x_{1} \oplus x_{5} \oplus x_{9} \oplus x_{12} \oplus x_{20} \\
z_{6}=x_{6} \oplus x_{10} \oplus x_{17} \oplus x_{21} \oplus x_{28} \\
z_{7}=x_{3} \oplus x_{7} \oplus x_{11} \oplus x_{14} \oplus x_{18}
\oplus x_{22} \oplus x_{29} \\
z_{8}=x_{8} \oplus x_{12} \oplus x_{23} \\
z_{9}=x_{9} \oplus x_{13} \oplus x_{20} \oplus x_{24} \oplus x_{31} \\
z_{10}=x_{6} \oplus x_{10} \oplus x_{17} \\
z_{11}=x_{0} \oplus x_{11} \\
z_{12}=x_{1} \oplus x_{8} \oplus x_{12} \oplus x_{19} \\
z_{13}=x_{2} \oplus x_{9} \oplus x_{13} \oplus x_{17} \oplus x_{20}
\oplus x_{28} \\
z_{14}=x_{3} \oplus x_{14} \oplus x_{18} \oplus x_{29} \\
z_{15}=x_{4} \oplus x_{15} \\
z_{16}=x_{5} \oplus x_{16} \\
z_{17}=x_{6} \oplus x_{13} \oplus x_{17} \oplus x_{24} \\
z_{18}=x_{0} \oplus x_{4} \oplus x_{15} \oplus x_{18} \\
z_{19}=x_{1} \oplus x_{5} \oplus x_{8} \oplus x_{15} \oplus x_{16}
\oplus x_{19} \oplus x_{26} \\
z_{20}=x_{2} \oplus x_{6} \oplus x_{9} \oplus x_{13} \oplus x_{17}
\oplus x_{20} \oplus x_{24} \\
z_{21}=x_{3} \oplus x_{17} \oplus x_{21} \oplus x_{28} \\
z_{22}=x_{0} \oplus x_{4} \oplus x_{8} \oplus x_{15} \oplus x_{18}
\oplus x_{19} \oplus x_{22} \oplus x_{26} \oplus x_{29} \\
z_{23}=x_{1} \oplus x_{5} \oplus x_{9} \oplus x_{20} \oplus x_{23} \\
z_{24}=x_{6} \oplus x_{10} \oplus x_{13} \oplus x_{17} \oplus x_{20}
\oplus x_{21} \oplus x_{24} \oplus x_{28} \oplus x_{31} \\
z_{25}=x_{3} \oplus x_{7} \oplus x_{11} \oplus x_{18} \oplus x_{22}
\oplus x_{25} \oplus x_{29} \\
z_{26}=x_{8} \oplus x_{12} \oplus x_{15} \oplus x_{23} \oplus x_{26} \\
z_{27}=x_{9} \oplus x_{13} \oplus x_{16} \oplus x_{20} \oplus x_{24}
\oplus x_{27} \oplus x_{31} \\
z_{28}=x_{6} \oplus x_{10} \oplus x_{28} \\
z_{29}=x_{0} \oplus x_{11} \oplus x_{18} \oplus x_{29} \\
z_{30}=x_{1} \oplus x_{8} \oplus x_{12} \oplus x_{30} \\
z_{31}=x_{2} \oplus x_{9} \oplus x_{13} \oplus x_{17} \oplus x_{28}
\oplus x_{31}
\end{array}
\] 也就是说,存在一个 \(\mathrm{GF}(2)\) 的矩阵可以将 \(X\) 变化为 \(Z\) ,即 \[
XT=Z
\] 我们可以采用黑盒测试的方式计算出 \(T\) ,例如当 \(X=(1,0,\cdots,0)\) 时,得到矩阵 \(T\) 的第一行。
使用 sagemath 编写代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 def extract_number (x ): x ^^= x >> 11 x ^^= (x << 7 ) & 2636928640 x ^^= (x << 15 ) & 4022730752 x ^^= x >> 18 return x & 0xffffffff T = [] for i in range (32 ): x = 1 << (31 - i) x = extract_number(x) T.append(x.digits(2 , padto=32 )[::-1 ]) T = matrix(GF(2 ), T) def untemper (leak ): Z = matrix(GF(2 ), ZZ(leak).digits(2 , padto=32 )[::-1 ]) X = T.solve_left(Z) return reduce(lambda x, y: (ZZ(x) << 1 ) + ZZ(y), list (X[0 ]))
同样可以获得 untemper
。
这种方式的好处是可以针对不同的梅森随机数参数。
旋转破解
我们考虑旋转算法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 m = 397 a = 0x9908b0df b = 0x9d2c5680 c = 0xefc60000 kInitOperand = 0x6c078965 kMaxBits = 0xffffffff kUpperBits = 0x80000000 kLowerBits = 0x7fffffff def twist (self ): for i in range (self.n): x = self._int32((self.register[i] & 0x8 ) + (self.register[(i + 1 ) % self.n] & self.kLowerBits)) self.register[i] = self.register[(i + self.m) % self.n] ^ (x >> 1 ) if x & 1 != 0 : self.register[i] ^= 0x9908b0df self.index = 0
发现新的 newState[i]
取决于旧的
state[i],state[i+1],state[i+397]
。
其中,由于
newState[i] = state[i+397] ^ (x >> 1)
,故
newState[i]
的最高位只与 state[i+397]
和
0x9908b0df
。
由于旋转的时候是迭代进行的,故 newState[623]
与
newState[396]
有关系,显然我们可以确定它,故我们通过此可以还原出
newState[396] ^ (x >> 1)
,state[623]
的最高位(即 x
的最高位)和 x
的最低位(会发生异或一定是 x & 1 != 0
)。
而同理,由于 newState[396]
是确定的,所以可以解得
x >> 1
的结果,故我们最终还原出了 x
,即
state[623]
的最高位和 newState[0]
的剩余位。
通过这种办法迭代可得旋转前的 state
。
可以写出代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 n = 624 m = 397 a = 0x9908b0df b = 0x9d2c5680 c = 0xefc60000 kInitOperand = 0x6c078965 kMaxBits = 0xffffffff kUpperBits = 0x80000000 kLowerBits = 0x7fffffff def untwist (newState, flag: bool = True ): oldState = [0 ] * 624 for i in range (n - 1 , -2 , -1 ): x = newState[i] ^ newState[(i + m) % n] if x & kUpperBits == kUpperBits: x ^= a x <<= 1 x |= 1 else : x <<= 1 if i > -1 : oldState[i] |= x & kUpperBits if i + 1 < n: oldState[i + 1 ] |= x & kLowerBits if i == 227 and flag: newState = list (newState[:227 ]) + oldState[227 :] return oldState