-
Notifications
You must be signed in to change notification settings - Fork 0
/
vsock.py
81 lines (69 loc) · 2.66 KB
/
vsock.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import socket, threading
import enclave.app as enclave
import skrecovery.config as config
# Wire framing: each message is preceded by a fixed-width header holding the
# byte length of the encoded payload as space-padded decimal text (see send).
HEADER = 64
FORMAT = "utf-8"
# Cap on the number of bytes requested per receive call (600 MiB).
BUFFER_SIZE = 600 << 20
# Sentinel payload a client sends to end its session with the server.
DISCONNECT_MESSAGE = "<<EOT>>"

# When config.is_nitro_env() is true (presumably inside an AWS Nitro enclave),
# use vsock bound to any CID; otherwise use TCP on this host's resolved address.
# NOTE: is_nitro_env() is evaluated once per constant, exactly as before.
if config.is_nitro_env():
    SERVER = socket.VMADDR_CID_ANY
else:
    SERVER = socket.gethostbyname(socket.gethostname())
if config.is_nitro_env():
    SOCK_FAMILY = socket.AF_VSOCK
else:
    SOCK_FAMILY = socket.AF_INET
ADDR = (SERVER, config.VSOCK_PORT)
def server_create(address: tuple = None) -> socket.socket:
    """Create a listening stream socket.

    Args:
        address: Address to bind to; when falsy, the module-level ``ADDR``
            is used. The socket family is ``SOCK_FAMILY``.

    Returns:
        A bound socket that is already listening.

    Raises:
        OSError: If binding or listening fails. The socket is closed first,
            so a failed start does not leak a file descriptor.
    """
    server = socket.socket(SOCK_FAMILY, socket.SOCK_STREAM)
    try:
        server.bind(address if address else ADDR)
        server.listen()
    except BaseException:
        # Release the descriptor on any failure, then let the error propagate.
        server.close()
        raise
    return server
def connect(address: tuple = None, timeout: float = None) -> socket.socket:
    """Open a client stream socket connected to ``address``.

    Args:
        address: Peer address; when falsy, the module-level ``ADDR`` is used.
            The socket family is ``SOCK_FAMILY``.
        timeout: Optional socket timeout in seconds, applied before
            connecting (so it covers the connect and later I/O). ``None``
            leaves the socket's default timeout untouched.

    Returns:
        The connected socket.

    Raises:
        OSError: If connecting fails. The socket is closed first, so a failed
            attempt does not leak a file descriptor.
    """
    client: socket.socket = socket.socket(SOCK_FAMILY, socket.SOCK_STREAM)
    try:
        if timeout is not None:
            client.settimeout(timeout)
        client.connect(address if address else ADDR)
    except BaseException:
        # Release the descriptor on any failure, then let the error propagate.
        client.close()
        raise
    return client
def disconnect(conn: socket.socket):
    """Send ``DISCONNECT_MESSAGE`` to the peer, then close ``conn``.

    The socket is closed even if sending the sentinel fails (for example,
    because the peer has already gone away); the send error still
    propagates to the caller.
    """
    try:
        send(conn, DISCONNECT_MESSAGE)
    finally:
        conn.close()
def recv_fixed_msg(conn: socket.socket, msg_length: int):
    """Receive exactly ``msg_length`` bytes from ``conn`` and decode them.

    ``msg_length`` is a byte count, matching the header written by ``send``,
    which measures the *encoded* payload. All bytes are collected first and
    decoded with ``FORMAT`` once at the end, for two reasons:

    * Decoding each chunk on its own fails when a multi-byte character
      straddles two receive calls.
    * Counting decoded characters instead of bytes makes the loop wait for
      data that was never sent, or consume the next frame.

    Args:
        conn: Connected stream socket to read from.
        msg_length: Number of bytes to read. Values <= 0 yield ``''``.

    Returns:
        The decoded message.

    Raises:
        ConnectionError: If the peer closes the connection before
            ``msg_length`` bytes have arrived.
        UnicodeDecodeError: If the received bytes are not valid ``FORMAT``.
    """
    if msg_length <= 0:
        # Nothing to read (the previous implementation also returned '').
        return ''
    buf = bytearray(msg_length)
    view = memoryview(buf)
    received = 0
    while received < msg_length:
        # Read straight into the preallocated buffer. BUFFER_SIZE caps how
        # much a single call may request.
        chunk_size = min(BUFFER_SIZE, msg_length - received)
        n = conn.recv_into(view[received:], chunk_size)
        if n == 0:
            raise ConnectionError(
                f"connection closed after {received} of {msg_length} bytes"
            )
        received += n
    return buf.decode(FORMAT)
def response_recv(conn: socket.socket) -> str:
    """Receive one length-prefixed message (header plus body) from ``conn``.

    The header is exactly ``HEADER`` bytes: the body's byte length as
    decimal text, space-padded (as written by ``send``). It is read in a
    loop because a single ``recv`` may return fewer bytes than requested.
    A short read would otherwise parse a truncated length and leave the
    rest of the header to be misread as message body.

    Returns:
        The decoded message, or ``''`` if the peer closed the connection
        before sending any header bytes.

    Raises:
        ConnectionError: If the peer closes the connection part-way through
            the header.
        ValueError: If the header does not contain an integer.
        Errors from ``recv_fixed_msg`` while reading the body propagate.
    """
    header = b''
    while len(header) < HEADER:
        chunk = conn.recv(HEADER - len(header))
        if not chunk:
            break
        header += chunk
    if not header:
        # Clean EOF before any data: keep the historical '' result.
        return ''
    if len(header) < HEADER:
        raise ConnectionError(
            f"connection closed after {len(header)} of {HEADER} header bytes"
        )
    msg_length = int(header.decode(FORMAT))
    return recv_fixed_msg(conn, msg_length)
def server_handle_client_connection(conn: socket.socket, addr):
    """Serve requests from one client until it disconnects, then close ``conn``.

    Each request is a length-prefixed frame (a ``HEADER``-byte length
    followed by the body). The body is passed to ``enclave.run(req=...)``,
    and the string result is sent back with ``send``. The loop ends when
    any of the following happens:

    * The client sends ``DISCONNECT_MESSAGE``.
    * The client closes the connection (cleanly or mid-header).
    * An error propagates.

    ``conn`` is closed in every case.
    """
    print(f"[NEW CONNECTION] {addr} connected.")
    try:
        while True:
            # Read the fixed-width header exactly; recv may return fewer
            # bytes than requested.
            header = b''
            while len(header) < HEADER:
                chunk = conn.recv(HEADER - len(header))
                if not chunk:
                    break
                header += chunk
            if len(header) < HEADER:
                # The peer closed without DISCONNECT_MESSAGE. Stop here:
                # looping again would spin forever on recv() returning b''.
                break
            msg_length = int(header.decode(FORMAT))
            print('Message length: ', msg_length)
            msg = recv_fixed_msg(conn, msg_length)
            if msg == DISCONNECT_MESSAGE:
                break
            print('Processing request...')
            res: str = enclave.run(req=msg)
            send(conn=conn, msg=res)
    finally:
        conn.close()
    print(f"[CONNECTION CLOSED] {addr} disconnected.")
def server_start(server: socket.socket):
    """Accept clients on ``server`` forever, serving each on its own thread.

    Never returns normally. Every accepted connection is handed to
    ``server_handle_client_connection`` on a newly started thread, and the
    current thread count is printed after each start.
    """
    print(f"[LISTENING] Server is listening on {SERVER}:{config.VSOCK_PORT}")
    # server.accept never returns None, so this iterates until accept raises.
    for client_conn, client_addr in iter(server.accept, None):
        handler = threading.Thread(
            target=server_handle_client_connection,
            args=(client_conn, client_addr),
        )
        handler.start()
        # The "- 1" presumably excludes the accepting (main) thread.
        print(f"[ACTIVE CONNECTIONS] {threading.active_count() - 1}")
def send(conn: socket.socket, msg: str):
    """Send ``msg`` to ``conn`` as one length-prefixed frame.

    Frame layout:

    * a ``HEADER``-byte header holding the byte length of the
      ``FORMAT``-encoded payload as decimal digits, right-padded with spaces;
    * the payload itself.

    Raises:
        ValueError: If the payload length needs more than ``HEADER`` digits.
        OSError: If writing to the socket fails.
    """
    payload = msg.encode(FORMAT)
    header = str(len(payload)).encode(FORMAT)
    if len(header) > HEADER:
        # Padding would be negative and the receiver would mis-frame the stream.
        raise ValueError(
            f"message of {len(payload)} bytes exceeds header capacity"
        )
    # sendall (not send): a partial header write would desync the stream.
    conn.sendall(header.ljust(HEADER, b" "))
    conn.sendall(payload)