NTRU密码体制笔记

NTRU 原理

NTRU 是一种基于环与格的公钥密码系统,特点是密钥短且容易产生,算法运算速度快,所需存储空间小。

格问题

基本问题介绍

考虑一个简单的加密方式,随机选取素数 \(p,g\),随机数 \(f\),满足 \(g<f<p\),我们计算 \[ h=gf^{-1}\pmod{p} \] 那么公钥为 \((h,p)\),私钥为 \((f,g)\)

定义加密方式为 \[ c=r\cdot h+m\pmod{p} \] 其中,\(r\) 为随机数,满足 \(r<p\)

定义解密方式为 \[ c\cdot f=r\cdot g+m\cdot f\pmod{p} \] 在这里要满足 \(rg+mf<p\),那么可以得知 \[ a=rg+mf=mf\pmod{g} \] 那么求解问题变为 \[ m=af^{-1}\pmod{g} \]

在这里简化了很多步骤,比如说保证 \(\gcd(f,g)=1\) 等条件,其中 \(rg+mf<p\) 是通过控制 \(h, p,f,g\) 的大小实现的。

当我们手里有密文 \(c\),公钥 \((h,p)\) 时,要想破解明文,需要得知密钥 \((f,g)\)

攻击手段

首先考虑等式 \[ h=gf^{-1}\pmod{p} \] 对其进行变形,得到 \[ f\cdot h=g\pmod{p}=g+kp \] 改写为矩阵乘法,得到 \[ (f,-k)M=(f,g)\newline M=\begin{bmatrix} 1&h\newline 0&p \end{bmatrix} \] 可以证明向量 \((f,g)\) 在格 \(M\) 上,即 SVP 问题。

详细证明在另一篇文章中,暂不在这里讨论。

一个二维格使用 LLL 算法是容易求解的。

生成代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from Crypto.Util.number import getPrime, getRandomNBitInteger
from math import gcd

def gen_key():
p = getPrime(2048)
f = getRandomNBitInteger(1024)
g = getPrime(768)
assert gcd(f, p) == 1
assert gcd(f, g) == 1
h = pow(f, -1, p) * g % p
return (h, p), (f, g)

def encrypt(h, p, m):
r = getRandomNBitInteger(1024)
return (r * h + m) % p

def decrypt(f, g, c):
a = c * f % p
return a * pow(f, -1, g) % g

pubkey, prikey = gen_key()
h, p = pubkey
f, g = prikey
m = 1234567890123456789012345678901234567890
c = encrypt(h, p, m)

print(f'{p = }')
print(f'{h = }')
print(f'{c = }')

攻击代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from sage.all import *

p = 27594003389147607369876674962283540812923723484764896504615301667891769682392830700505146565378653539119600813995142280915192657472008749757053428617533851699041596289206274929457178724207236537704942588471957686793501676280083577625938001845200938875371121059921018339291926310189890910674291154395560577741025165026637266300202408364846598547810207047635737497984357341153634605326370312560794859152109470507930707976266978869848114875210278872802460387560896289447402484841694897117444468955804897349232138924990711736063541879374827091472825348941445564520557929969273920235937179897137468129678469879486475239587
h = 2071693178846115365256723483327725178503050954465496579029808248184888643303185618308763733997615320553015405690744070807211482128071409950264282977319266191487261103502555226760220546234498149656711925740338844829729591708920986465468222690445385909699330163332174162914930634801187698298857635172034342938401155529241042690638568002082750744828569527268706825896607293470070519508750808672297737672589325372550690338703159708325314400119604285890879839665019432945211057077954724304391457748178158342102678647638675140789772803542969160681336018722064996045749041652415620028221116584104045162570850741666319496552
c = 7671911440878532433667489121828814206592905285054408925890957479601943566414379514033874765501550498736499255991516070982428829754215761379185332958698647996593036343058593367027949092535270126555289066891682542858251520630282615219113992893538755011967908072085016169215108544137084022073238869205354768601011253598570094011049868057911782544371039850541623949887912456762381409507704839535353567221746097905226925531782002899815658215746488155191552172645625894574414478454147438702085549693734515817389107129970039603841526722724017093665065816242864251172298040114963542718789445480202872102034318566489282326224

def decrypt(f, g, c):
a = c * f % p
return a * pow(f, -1, g) % g

M = matrix([[1, h], [0, p]])
L = M.LLL()
shortest_vector = L[0]
if shortest_vector < 0:
shortest_vector = -shortest_vector
f, g = shortest_vector
m = decrypt(f, g, c)
print(f'{m = }')

多项式环与格

多项式环

  1. \(R(N)\) 表示最高次数不超过 \(N-1\) 的所有整系数多项式集合,即 \(a=a_0+a_1x+\cdots+a_{N-1}x^{N-1}\)

  2. \(R\) 上的加法定义为:\(a+b=(a_0+b_0)+(a_1+b_1)x+\cdots+(a_{N-1}+b_{N-1})x^{N-1}\)

  3. \(R\) 上的乘法定义为:\(a\times b=c_0+c_1x+\cdots+c_{N-1}x^{N-1}\),其中第 \(k\) 阶系数为: \[ \sum_{i+j\equiv k\pmod{N}}a_ib_i \]

由以上定义可知,\(R\) 的加法和乘法运算构成一个环。

概述

基于前面的加密算法,我们不难发现破解该算法的主要核心点就在于格攻击上,如果说我们将格的维度提高,那么就很自然的提高了攻击成本。

我们考虑将原本的整数环转化为多项式环,在多项式环进行相应的加解密运算,此时如果还想要对其进行攻击,就必须构建格 \[ \mathscr{L}=\begin{bmatrix} \lambda I&H\newline 0&qI \end{bmatrix}\newline H=\begin{bmatrix} h_{0} & h_{1} & \ldots & h_{N-1} \\ h_{N-1} & h_{0} & \ldots & h_{N-2} \\ \vdots & \vdots & \ddots & \vdots \\ h_{1} & h_{2} & \ldots & h_{0} \\ \end{bmatrix} \] 其中,\(H\) 是有关多项式 \(h\) 的系数循环矩阵,\(I\) 为单位矩阵,\(\lambda\) 为维持格大小的系数,\(q\) 是 NTRU 的参数。

这样很直接地增加了格的维度,使得时间复杂度显著提高,格规约结果变得不稳定,同时还能降低加密运算的复杂度(小的多项式系数照样有一定的安全强度)。

算法

这样根据上面的想法,我们可以改变一下原来算法的描述。

首先我们需要一些公共参数来确定多项式的阶和多项式系数的大小,并且做一定的约束。

公共参数

  • \(N\) 是一个素数,要求其足够大以扩展格攻击中格的维度,\(N-1\) 就是多项式环的阶,同时这也是公钥之一。
  • \(p,q\) 是互素的数,且 \(q\gg p\),这是为了保证密钥中的多项式 \(f\) 可以进行正常运算,选取的大小可以偏小,但不能太小(\(p\) 一般选择 3,\(q\) 一般选取 2 的倍数,例如 65536)。
  • 向量空间 \(L_f,L_g,L_r,L_m\)\(R(N)\)

密钥生成

  1. 随机选取两个多项式 \(f\in L_f,g\in L_g\),其中保证多项式 \(f\) 在模 \(p\) 和模 \(q\)​ 下均可逆,其逆元分别表示为 \(f_p,f_q\),即满足 \[ f_p\cdot f\equiv1\pmod{p}\newline f_q\cdot f\equiv1\pmod{q} \]

  2. 计算 \(h\equiv f_q\cdot g\pmod{q}\)

  3. 那么公钥为 \((h,N)\),私钥为 \((f,g)\)

加密

  1. 将信息映射到 \(L_m\) 上,定义为 \(m\in L_m\)

  2. 随机选取 \(r\in L_r\)

  3. 用公钥 \(h\) 对消息进行加密(这里有可能并不会与 \(p\) 相乘,而是计算 \(h=p\cdot f_q\cdot g\pmod{q}\)): \[ e\equiv p\cdot r\cdot h+m\pmod{q} \]

  4. \(e\) 即为密文

解密

  1. 计算 \(a\equiv f\cdot e\pmod{q}\),其中 \(a\) 的系数选在 \(\displaystyle(-\frac{q}{2},\frac{q}{2})\) 区间内,平衡系数方便计算。
  2. 计算 \(m=f_p\cdot a\pmod{p}\)

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from random import shuffle
from sage.all import *
from Crypto.Util.number import *
Zx = PolynomialRing(ZZ, 'x')
x = Zx.gen()

def convolution(f, g, R):
return (f * g) % R


def balancedmod(f, q, R):
g = list(map(lambda x: ((x + q//2) % q) - q//2, f.list()))
return Zx(g) % R


def random_poly(n, d1, d2):
assert d1 + d2 <= n
result = d1 * [1] + d2 * [-1] + (n - d1 - d2) * [0]
shuffle(result)
return Zx(result)


def invert_poly_mod_prime(f, R, p):
T = Zx.change_ring(Integers(p)).quotient(R)
return Zx(lift(1 / T(f)))


def invert_poly_mod_powerof2(f, R, q): # Hensel Lemma
g = invert_poly_mod_prime(f, R, 2)
e = log(q, 2)
for i in range(1, e):
g = ((2 * g - f * g ** 2) % R) % q
return g


class NTRUCipher:
def __init__(self, N, p, q, d):
self.N = N
self.p = p
self.q = q
self.d = d
self.R = x ** N - 1
# key generation
self.g = random_poly(self.N, d, d)

while True:
try:
self.f = random_poly(self.N, d + 1, d)
self.fp = invert_poly_mod_prime(self.f, self.R, self.p)
self.fq = invert_poly_mod_powerof2(self.f, self.R, self.q)
break
except:
pass
# 这里在计算公钥时就已经乘以p了
self.h = balancedmod(self.p * convolution(self.fq, self.g, self.R), self.q, self.R)
def getPubKey(self):
return self.h
def encrypt(self, m):
r = random_poly(self.N, self.d, self.d)
# 在这里并没有乘以p是因为公钥计算时已经乘进去了
return balancedmod(convolution(self.h, r, self.R) + m, self.q, self.R)

def decrypt(self, c):
a = balancedmod(convolution(c, self.f, self.R), self.q, self.R)
return balancedmod(convolution(a, self.fp, self.R), self.p, self.R)

def encode(self, val):
poly = 0
for i in range(self.N):
poly += ((val % self.p) - self.p // 2) * (x ** i)
val //= self.p
return poly

def decode(self, poly):
result = 0
ll = poly.list()
for idx, val in enumerate(ll):
result += (val + self.p // 2) * (self.p ** idx)
return result

def poly_from_list(self, l: list):
return Zx(l)


if __name__ == '__main__':
N = 100
d = 3
p = 3
q = 512
flag = b'flag{qsdz_yyds}'

cipher = NTRUCipher(N, p, q, d)
print("[PubKey]---------")
h = cipher.getPubKey()
print(f'{h = }')
msg = bytes_to_long(flag)
encode_msg = cipher.encode(msg)
e = cipher.encrypt(encode_msg)
print("[Cipher]---------")
print(f'{e = }')
mm = cipher.decrypt(e)
decode_msg = cipher.decode(mm)
assert decode_msg == msg

解密原理

\[ \begin{align*} a&\equiv f\cdot e\\ &\equiv f\cdot p\cdot r+f\cdot m\pmod{q}\\ &\equiv f\cdot p\cdot r \cdot f_qg+f\cdot m\pmod{q}\\ &\equiv p\cdot r\cdot g+f\cdot m\pmod{q} \end{align*} \]

其中 \(a\) 的系数选在 \(\displaystyle(-\frac{q}{2},\frac{q}{2})\) 区间内,故 \[ m=f_p\cdot a\equiv f_p\cdot p\cdot r\cdot g+f_p\cdot f\cdot m\pmod{p}\equiv m\pmod{p} \]

格攻击

如概述所说,我们需要构建一个格来对其进行攻击,当 \(N\) 较小时,可以很轻松地进行攻击,可以证明 \((f_0,\cdots,f_N,g_0,\cdots,g_N)\) 在格上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
N = 7
d = 2
p = 3
q = 256

# 如果h=p*g*fp,那么这里需要乘以p的逆元
hbar = Zx(h * pow(p, -1, q))
H = matrix.circulant(hbar.padded_list(N))
E = matrix.identity(N)
Q = q * E
M = matrix.block([E, H, 0, Q], nrows=2)
L = M.LLL()

cipher = NTRUCipher(N, p, q, d)
possible_ms = []
for shortest_vector in L:
shortest_vector = shortest_vector % p
f_coefficients = shortest_vector[:N]
g_coefficients = shortest_vector[N:]
f = Zx(list(f_coefficients))
g = Zx(list(g_coefficients))
cipher.f = f
possible_ms.append(cipher.decrypt(e))
print(possible_ms)

NTRU 求解

明文爆破

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import itertools
from Crypto.Hash import SHA3_256
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from secret import flag

# parameters
N = 10
p = 3
q = 512
d = 3
assert q>(6*d+1)*p

R.<x> = ZZ[]

#d1 1s and #d2 -1s
def T(d1, d2):
assert N >= d1+d2
s = [1]*d1 + [-1]*d2 + [0]*(N-d1-d2)
shuffle(s)
return R(s)

def invertModPrime(f, p):
Rp = R.change_ring(Integers(p)).quotient(x^N-1)
return R(lift(1 / Rp(f)))

def convolution(f, g):
return (f*g) % (x^N-1)

def liftMod(f, q):
g = list(((f[i] + q//2) % q) - q//2 for i in range(N))
return R(g)

def polyMod(f, q):
g = [f[i]%q for i in range(N)]
return R(g)

def invertModPow2(f, q):
assert q.is_power_of(2)
g = invertModPrime(f,2)
while True:
r = liftMod(convolution(g,f),q)
if r == 1: return g
g = liftMod(convolution(g,2 - r),q)

def genMessage():
result = list(randrange(p) - 1 for j in range(N))
return R(result)

def genKey():
while True:
try:
f = T(d+1, d)
g = T(d, d)
Fp = polyMod(invertModPrime(f, p), p)
Fq = polyMod(invertModPow2(f, q), q)
break
except:
continue
h = polyMod(convolution(Fq, g), q)
return h, (f, g)

def encrypt(m, h):
e = liftMod(p*convolution(h, T(d, d)) + m, q)
return e

# Step 1
h, secret = genKey()
m = genMessage()
e = encrypt(m, h)

print('h = %s' % h)
print('e = %s' % e)

result = list(randrange(p) - 1 for j in range(N))
for i in itertools.product(range(-1, p), repeat=N):
i = list(i)
m = R(i)
if encrypt(m, h) == e:
print(m)
break
print('end')

# Step 2
sha3 = SHA3_256.new()
sha3.update(bytes(str(m).encode('utf-8')))
key = sha3.digest()

cypher = AES.new(key, AES.MODE_ECB)
c = cypher.encrypt(pad(flag, 32))
print('c = %s' % c)

输出为

1
2
3
h = 39*x^9 + 60*x^8 + 349*x^7 + 268*x^6 + 144*x^5 + 469*x^4 + 449*x^3 + 165*x^2 + 248*x + 369
e = -144*x^9 - 200*x^8 - 8*x^7 + 248*x^6 + 85*x^5 + 102*x^4 + 167*x^3 + 30*x^2 - 203*x - 78
c = b'\xb9W\x8c\x8b\x0cG\xde\x7fl\xf7\x03\xbb9m\x0c\xc4L\xfe\xe9Q\xad\xfd\xda!\x1a\xea@}U\x9ay4\x8a\xe3y\xdf\xd5BV\xa7\x06\xf9\x08\x96="f\xc1\x1b\xd7\xdb\xc1j\x82F\x0b\x16\x06\xbcJMB\xc8\x80'

由于明文多项式较小,直接爆破即可:

首先各种多项式 m 的值

1
2
3
4
5
fp = open('m.txt', 'w')
for i in itertools.product(range(p), repeat=N):
r = list(j-1 for j in i)
m = R(r)
print(m, file=fp)

然后直接爆破即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from Crypto.Hash import SHA3_256
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad

with open('m.txt', 'r') as f:
data = f.read().splitlines()
print(data)
c = b'\xb9W\x8c\x8b\x0cG\xde\x7fl\xf7\x03\xbb9m\x0c\xc4L\xfe\xe9Q\xad\xfd\xda!\x1a\xea@}U\x9ay4\x8a\xe3y\xdf\xd5BV\xa7\x06\xf9\x08\x96="f\xc1\x1b\xd7\xdb\xc1j\x82F\x0b\x16\x06\xbcJMB\xc8\x80'
for m in data:
sha3 = SHA3_256.new()
sha3.update(bytes(str(m).encode('utf-8')))
key = sha3.digest()

cypher = AES.new(key, AES.MODE_ECB)
k = cypher.decrypt(c)
if b'DASCTF' in k:
print(f"{k=}")