From 6edf6b7e23803ab27c4190987184657e39a38e02 Mon Sep 17 00:00:00 2001 From: Christian Ulrich Date: Tue, 17 Nov 2020 20:40:30 +0100 Subject: [PATCH] add UDP hole punching (untested) --- message.nim | 55 ++++++++++++++ port_prediction.nim | 165 ++++++++++++++++++++++++++++++++++++++++++ puncher.nim | 84 +++++++++++++++++++++ quicp2p.nim | 157 +++++++++++++++++++++++++++------------- server_connection.nim | 104 ++++++++++++++++++++++++++ 5 files changed, 515 insertions(+), 50 deletions(-) create mode 100644 message.nim create mode 100644 port_prediction.nim create mode 100644 puncher.nim create mode 100644 server_connection.nim diff --git a/message.nim b/message.nim new file mode 100644 index 0000000..7782664 --- /dev/null +++ b/message.nim @@ -0,0 +1,55 @@ +import strutils +from net import IpAddress, parseIpAddress, Port, `$` + +proc parseField(input: string, output: var string) = + output = input + +proc parseField[T: SomeUnsignedInt](input: string, output: var T) = + let parsed = parseUInt(input) + if parsed > T.high: + raise newException(ValueError, "Unsigned integer out of range") + output = parsed.T + +proc parseField(input: string, output: var IpAddress) = + output = parseIpAddress(input) + +proc parseField(input: string, output: var Port) = + var portNumber: uint16 + parseField(input, portNumber) + output = Port(portNumber) + +proc parseField[S, T](input: string, output: var array[S, T]) = + let parts = input.split(",", S.high) + if parts.len != S.high + 1: + raise newException(ValueError, "Array has wrong length") + for i in 0 .. S.high: + parseField(parts[i], output[i]) + +proc parseField[T](input: string, output: var seq[T]) = + let parts = input.split(",") + if parts.len < 1: + raise newException(ValueError, "Sequence is empty") + output = newSeq[T](parts.len) + for i in 0 .. parts.len - 1: + parseField(parts[i], output[i]) + +proc parseField[T: tuple | object](input: string, output: var T) = + var fieldCount = 0 + for _ in output.fields: + fieldCount = fieldCount + 1 + let args = input.parseArgs(fieldCount) + var i = 0 + for value in output.fields: + parseField(args[i], value) + i.inc + +proc parseArgs*(input: string, count: int, optionalCount = 0): seq[string] = + assert(optionalCount <= count) + result = input.split("|", count - 1) + if result.len < count: + if result.len < count - optionalCount: + raise newException(ValueError, "invalid message") + result.add(repeat("", count - result.len)) + +proc parseMessage*[T: tuple | object](input: string): T = + parseField(input, result) diff --git a/port_prediction.nim b/port_prediction.nim new file mode 100644 index 0000000..8369bc1 --- /dev/null +++ b/port_prediction.nim @@ -0,0 +1,165 @@ +import algorithm +import net +import sequtils +import unittest + +const RandomPortCount = 1000 + +proc min(a, b: uint16): uint16 = + min(a.int32, b.int32).uint16 + +proc toUint16(p: Port): uint16 = uint16(p) + +proc toPort(u: uint16): Port = Port(u) + +proc addOffset(port: uint16, offset: uint16, minValue = 1024'u16, + maxValue = uint16.high): uint16 = + assert(port >= minValue) + assert(port <= maxValue) + let distanceToMaxValue = maxValue - port + if distanceToMaxValue < offset: + return minValue + offset - distanceToMaxValue - 1 + return port + offset + +proc subtractOffset(port: uint16, offset: uint16, minValue = 1024'u16, + maxValue = uint16.high): uint16 = + assert(port >= minValue) + assert(port <= maxValue) + let distanceToMinValue = port - minValue + if distanceToMinValue < offset: + return maxValue - offset + distanceToMinValue + 1 + return port - offset + +proc predictPortRange*(localPort: Port, probedPorts: seq[Port]): seq[Port] = + if probedPorts.len == 0: + # No probed ports, so our only guess can be that the NAT is a cone-type NAT + # and the port mapping preserves the local Port. + return @[localPort] + let localPortUint = localPort.uint16 + let probedPortsUint = probedPorts.map(toUint16) + if probedPorts.len == 1: + # Only one server was used for probing, so we cannot know if the NAT is + # symmetric or not. We are trying the probed port (assuming cone-type NAT) + # and the next port in a progressive sequence if applicable (assuming + # symmetric NAT with progressive port mapping). + result.add(probedPorts[0]) + if probedPortsUint[0] > localPortUint: + let offset = probedPortsUint[0] - localPortUint + result.add(Port(probedPortsUint[0].addOffset(offset))) + elif probedPortsUint[0] < localPortUint: + let offset = localPortUint - probedPortsUint[0] + result.add(Port(probedPortsUint[0].subtractOffset(offset))) + return + let deduplicatedPorts = probedPortsUint.deduplicate() + if deduplicatedPorts.len() == 1: + # It looks like the NAT is a cone-type NAT. + return deduplicatedPorts.map(toPort) + let probedPortsSorted = probedPortsUint.sorted() + let minPort = probedPortsSorted[probedPortsSorted.minIndex()] + let maxPort = probedPortsSorted[probedPortsSorted.maxIndex()] + var minDistance = uint16.high() + var maxDistance = uint16.low() + for i in 1 .. probedPortsSorted.len() - 1: + # FIXME: use rotated distance + let distance = probedPortsSorted[i] - probedPortsSorted[i - 1] + minDistance = min(minDistance, distance) + maxDistance = max(maxDistance, distance) + if maxDistance < 10: + if probedPortsUint.isSorted(Ascending): + # assume symmetric NAT with positive-progressive port mapping + if minDistance == maxDistance: + return @[Port(maxPort.addOffset(maxDistance))] + else: + for i in countup(0'u16, maxDistance): + result.add(Port(minPort.addOffset(i))) + return + if probedPortsUint.isSorted(Descending): + # assume symmetric NAT with negative-progressive port mapping + if minDistance == maxDistance: + return @[Port(minPort.subtractOffset(maxDistance))] + else: + for i in countup(0'u16, maxDistance): + result.add(Port(maxPort.subtractOffset(i))) + return + # assume symmetric NAT with random port mapping + let portRange = maxPort - minPort + let first = if portRange > RandomPortCount: + minPort + else: + let notCovered = RandomPortCount - portRange + max(minPort - notCovered shr 1, 1024) + let last = first + RandomPortCount + for i in first .. last: + result.add(Port(i)) + +suite "port prediction tests": + test "single port": + let predicted = predictPortRange(Port(1234), @[]) + check(predicted == @[Port(1234)]) + + test "single probe equal": + let predicted = predictPortRange(Port(1234), @[Port(1234)]) + check(predicted == @[Port(1234)]) + + test "single probe positive-progressive": + let predicted = predictPortRange(Port(1234), @[Port(1236)]) + check(predicted == @[Port(1236), Port(1238)]) + + test "single probe negative-progressive": + let predicted = predictPortRange(Port(1234), @[Port(1232)]) + check(predicted == @[Port(1232), Port(1230)]) + + test "all equal": + let predicted = predictPortRange(Port(1234), @[Port(1234), Port(1234)]) + check(predicted == @[Port(1234)]) + + test "positive-progressive, offset 1": + let predicted = predictPortRange(Port(1234), @[Port(2034), Port(2035)]) + check(predicted == @[Port(2036)]) + + test "positive-progressive, offset 9": + let predicted = predictPortRange(Port(1234), @[Port(2034), Port(2043)]) + check(predicted == @[Port(2052)]) + + test "negative-progressive, offset 1": + let predicted = predictPortRange(Port(1234), @[Port(1100), Port(1099)]) + check(predicted == @[Port(1098)]) + + test "negative-progressive, offset 9": + let predicted = predictPortRange(Port(1234), @[Port(1100), Port(1091)]) + check(predicted == @[Port(1082)]) + + test "positive-progressive, 3 probed ports, low offset": + let predicted = predictPortRange(Port(1234), @[Port(2000), Port(2000), Port(2002)]) + check(predicted == @[Port(2000), Port(2001), Port(2002)]) + + test "negative-progressive, 3 probed ports, low offset": + let predicted = predictPortRange(Port(1234), @[Port(2002), Port(2000), Port(2000)]) + check(predicted == @[Port(2002), Port(2001), Port(2000)]) + + test "high port, positive-progressive, offset 1": + let predicted = predictPortRange(Port(1234), @[Port(65534), Port(65535)]) + check(predicted == @[Port(1024)]) + + test "high port, positive-progressive, offset 9": + let predicted = predictPortRange(Port(1234), @[Port(65520), Port(65529)]) + check(predicted == @[Port(1026)]) + + test "low port, negative-progressive, offset 1": + let predicted = predictPortRange(Port(1234), @[Port(1025), Port(1024)]) + check(predicted == @[Port(65535)]) + + test "low port, negative-progressive, offset 9": + let predicted = predictPortRange(Port(1234), @[Port(1039), Port(1030)]) + check(predicted == @[Port(65533)]) + + test "random mapping, distance > RandomPortCount": + let predicted = predictPortRange(Port(1234), @[Port(3546), Port(7624)]) + check(predicted == toSeq(countup(3546'u16, 3546'u16 + RandomPortCount)).map(toPort)) + + test "random mapping, distance < RandomPortCount": + let centerPort = 30000'u16 + let minPort = centerPort - RandomPortCount.uint16 shr 1 + 1 + let maxPort = centerPort + RandomPortCount.uint16 shr 1 - 1 + let predicted = predictPortRange(Port(centerPort), @[Port(minPort), Port(maxPort)]) + check(predicted == toSeq(countup(minPort - 1, maxPort + 1)).map(toPort)) diff --git a/puncher.nim b/puncher.nim new file mode 100644 index 0000000..d985806 --- /dev/null +++ b/puncher.nim @@ -0,0 +1,84 @@ +import asyncdispatch, asyncnet, net, port_prediction + +from nativesockets import SockAddr, SockAddr_storage, SockLen +from sequtils import any + +type + Attempt = object + ## A hole punching attempt. + srcPort: Port + dstIp: IpAddress + dstPorts: seq[Port] + future: Future[Port] + + Puncher* = ref object + sock: AsyncSocket + attempts: seq[Attempt] + + PunchHoleError* = object of ValueError + +const Timeout = 3000 + +proc `==`(a, b: Attempt): bool = + ## ``==`` for hole punching attempts. + ## + ## Two hole punching attempts are considered equal if their ``srcPort`` and + ## ``dstIp`` are equal and their ``dstPorts`` overlap. + a.srcPort == b.srcPort and a.dstIp == b.dstIp and + a.dstPorts.any(proc (p: Port): bool = p in b.dstPorts) + +proc initPuncher*(sock: AsyncSocket): Puncher = + Puncher(sock: sock) + +proc punch(puncher: Puncher, peerIp: IpAddress, peerPort: Port, + peerProbedPorts: seq[Port], msg: string): Future[Port] {.async.} = + let punchFuture = newFuture[Port]("punch") + let predictedDstPorts = predictPortRange(peerPort, peerProbedPorts) + let (_, myPort) = puncher.sock.getLocalAddr() + let attempt = Attempt(srcPort: myPort, dstIp: peerIp, + dstPorts: predictedDstPorts, future: punchFuture) + if puncher.attempts.contains(attempt): + raise newException(PunchHoleError, + "hole punching for given parameters already active") + puncher.attempts.add(attempt) + var peerAddr: Sockaddr_storage + var peerSockLen: SockLen + try: + for dstPort in attempt.dstPorts: + toSockAddr(attempt.dstIp, dstPort, peerAddr, peerSockLen) + # TODO: replace asyncdispatch.sendTo with asyncnet.sendTo (Nim 1.4 required) + await sendTo(puncher.sock.getFd().AsyncFD, msg.cstring, msg.len, + cast[ptr SockAddr](addr peerAddr), peerSockLen) + await punchFuture or sleepAsync(Timeout) + if punchFuture.finished(): + result = punchFuture.read() + else: + raise newException(PunchHoleError, "timeout") + except OSError as e: + raise newException(PunchHoleError, e.msg) + +proc initiate*(puncher: Puncher, peerIp: IpAddress, peerPort: Port, + peerProbedPorts: seq[Port]): Future[Port] = + punch(puncher, peerIp, peerPort, peerProbedPorts, "SYN") + +proc respond*(puncher: Puncher, peerIp: IpAddress, peerPort: Port, + peerProbedPorts: seq[Port]): Future[Port] = + punch(puncher, peerIp, peerPort, peerProbedPorts, "ACK") + +proc handleMsg*(puncher: Puncher, msg: string, peerIp: IpAddress, + peerPort: Port) = + ## Handles an incoming UDP message which may complete the Futures returned by + ## ``initiate`` and ``respond``. + let (_, myPort) = puncher.sock.getLocalAddr() + let query = Attempt(srcPort: myPort, dstIp: peerIp, dstPorts: @[peerPort]) + let i = puncher.attempts.find(query) + if i != -1: + puncher.attempts[i].future.complete(peerPort) + puncher.attempts.del(i) + +proc handleMsg*(puncher: Puncher, msg: string, + peerAddr: SockAddr | Sockaddr_storage, peerSockLen: SockLen) = + var peerIp: IpAddress + var peerPort: Port + fromSockAddr(peerAddr, peerSockLen, peerIp, peerPort) + handleMsg(puncher, msg, peerIp, peerPort) diff --git a/quicp2p.nim b/quicp2p.nim index fae2e5a..6cdcf6a 100644 --- a/quicp2p.nim +++ b/quicp2p.nim @@ -4,11 +4,13 @@ import asyncdispatch import asyncnet import base32 import certificate +import message import net import os import openssl_additional import picotls/picotls import picotls/openssl as ptls_openssl +import puncher import quicly/quicly import quicly/cid import quicly/constants @@ -16,6 +18,8 @@ import quicly/defaults import quicly/recvstate import quicly/sendstate import quicly/streambuf +import random +import server_connection import strformat import strutils @@ -35,18 +39,15 @@ from openssl import PSTACK, d2i_X509 -const serverCertChainPath = "./certs/server-certchain.pem" -const serverKeyPath = "./certs/server-cert.key" -const clientCertChainPath = "./certs/client-certchain.pem" -const clientKeyPath = "./certs/client-cert.key" - type Connection = ref object conn: ptr quicly_conn_t certs: seq[Certificate] + peerId: string QuicP2PContext = ref object sock: AsyncSocket + puncher: Puncher streamOpen: quicly_stream_open_t nextCid: quicly_cid_plaintext_t signCertCb: ptls_openssl_sign_certificate_t @@ -55,6 +56,16 @@ type quiclyCtx: quicly_context_t connections: seq[Connection] +const serverCertChainPath = "./certs/server-certchain.pem" +const serverKeyPath = "./certs/server-cert.key" +const clientCertChainPath = "./certs/client-certchain.pem" +const clientKeyPath = "./certs/client-cert.key" + +const rendezvousServers: seq[tuple[hostname: string, port: Port]] = @[ + ("strangeplace.net", Port(5320)), + ("ulrich.earth", Port(5320)) +] + proc getRelativeTimeout(ctx: QuicP2PContext): int32 = ## Obtain the absolute int64 timeout from quicly and convert it to the ## relative int32 timeout expected by poll. @@ -197,7 +208,8 @@ proc verifyCerts(self: ptr ptls_verify_certificate_t, tls: ptr ptls_t, X509_STORE_free(store) X509_free(caCert) -proc initContext(sock: AsyncSocket, certChainPath: string, keyPath: string, +proc initContext(sock: AsyncSocket, puncher: Puncher, certChainPath: string, + keyPath: string, streamOpenCb: typeof(quicly_stream_open_t.cb)): QuicP2PContext = var tlsCtx = ptls_context_t(randomBytes: ptls_openssl_random_bytes, @@ -205,7 +217,7 @@ proc initContext(sock: AsyncSocket, certChainPath: string, keyPath: string, keyExchanges: ptls_openssl_key_exchanges, cipherSuites: ptls_openssl_cipher_suites) quicly_amend_ptls_context(addr tlsCtx) - result = QuicP2PContext(sock: sock, + result = QuicP2PContext(sock: sock, puncher: puncher, streamOpen: quicly_stream_open_t(cb: streamOpenCb), verifyCertsCb: ptls_verify_certificate_t(cb: verifyCerts), tlsCtx: tlsCtx, quiclyCtx: quicly_spec_context) @@ -223,10 +235,11 @@ proc initContext(sock: AsyncSocket, certChainPath: string, keyPath: string, EVP_PKEY_free(privateKey) result.tlsCtx.sign_certificate = addr result.signCertCb.super -proc addConnection(ctx: QuicP2PContext, connPtr: ptr quicly_conn_t) = +proc addConnection(ctx: QuicP2PContext, connPtr: ptr quicly_conn_t, + peerId: string) = assert(not connPtr.isNil) let data = quicly_get_data(connPtr) - var conn = Connection(conn: connPtr) + var conn = Connection(conn: connPtr, peerId: peerId) data[] = addr conn[] ctx.connections.add(conn) @@ -262,8 +275,25 @@ proc sendPackets(ctx: QuicP2PContext) = else: raise newException(ValueError, &"quicly_send returned {sendResult}") -proc handleMsg(ctx: QuicP2PContext, msg: string, peerAddr: ptr SockAddr, - isServer: bool) = +proc initiateQuicConnection(ctx: QuicP2PContext, peerId: string, + peerIp: IpAddress, peerPort: Port) = + var conn: ptr quicly_conn_t + var peerAddr: SockAddr_storage + var peerSockLen: SockLen + toSockAddr(peerIp, peerPort, peerAddr, peerSockLen) + let addressToken = ptls_iovec_init(nil, 0) + let connectResult = quicly_connect(addr conn, addr ctx.quiclyCtx, + peerId.cstring, addr peerAddr, nil, + addr ctx.nextCid, addressToken, nil, nil) + if connectResult != 0: + echo "quicly_connect failed: ", connectResult + return + var stream: ptr quicly_stream_t + discard quicly_open_stream(conn, addr stream, 0) + ctx.addConnection(conn, peerId) + +proc handleMsg(ctx: QuicP2PContext, msg: string, peerId: string, + peerAddr: ptr Sockaddr_storage, peerSockLen: SockLen) = var offset: csize_t = 0 while offset < msg.len().csize_t: var decoded: quicly_decoded_packet_t @@ -271,6 +301,10 @@ proc handleMsg(ctx: QuicP2PContext, msg: string, peerAddr: ptr SockAddr, cast[ptr uint8](msg.cstring), msg.len().csize_t, addr offset) if decode_result == csize_t.high: + # The puncher needs to be informed about this message because quicly not + # being able to decode it may indicate it's the peer's response to our + # initiate call. + ctx.puncher.handleMsg(msg, peerAddr[], peerSockLen) return var conn: ptr quicly_conn_t = nil for c in ctx.connections: @@ -279,12 +313,16 @@ proc handleMsg(ctx: QuicP2PContext, msg: string, peerAddr: ptr SockAddr, break if conn != nil: discard quicly_receive(conn, nil, peerAddr, addr decoded) - elif isServer: + elif peerId.len != 0: + # The puncher needs to be informed about this message because it may + # be the peer's response to our respond call. Quicly needs to be informed + # because we except the first QUIC handshake packet in it. + ctx.puncher.handleMsg(msg, peerAddr[], peerSockLen) discard quicly_accept(addr conn, addr ctx.quiclyCtx, nil, peerAddr, addr decoded, nil, addr ctx.nextCid, nil) - ctx.addConnection(conn) + ctx.addConnection(conn, peerId) -proc receive(ctx: QuicP2PContext, isServer: bool) {.async.} = +proc receive(ctx: QuicP2PContext, peerId: string) {.async.} = while true: # TODO: replace asyncdispatch.recvFromInto with asyncnet.recvFrom (Nim 1.4 required) var msg = newString(BufferSize) @@ -295,52 +333,71 @@ proc receive(ctx: QuicP2PContext, isServer: bool) {.async.} = addr peerAddrLen) msg.setLen(msgLen) if msg.len > 0: - handleMsg(ctx, msg, cast[ptr SockAddr](addr peerAddr), isServer) + handleMsg(ctx, msg, peerId, addr peerAddr, peerAddrLen) + +proc handleNotification(ctx: QuicP2PContext, notification: NotifyPeer) + {.async.} = + let _ = await ctx.puncher.respond(notification.dstIp, notification.dstPort, + notification.probedDstPorts) + +proc runApp(ctx: QuicP2PContext, srcPort: Port, peerId: string) {.async.} = + let serverConn = await initServerConnection(rendezvousServers[0].hostname, + rendezvousServers[0].port, + srcPort, rendezvousServers) + asyncCheck handleServerMessages(serverConn) + asyncCheck receive(ctx, peerId) + + if peerId.len == 0: + # We are the responder + let probedPorts = serverConn.probedSrcPorts.join(",") + let req = &"{ctx.getPeerId()}|{serverConn.probedIp}|{srcPort}|{probedPorts}" + discard await serverConn.sendRequest("register", req) + while true: + let (hasData, data) = await serverConn.peerNotifications.read() + if not hasData: + break + try: + let msg = parseMessage[NotifyPeer](data) + # FIXME: check if we want to receive messages from the sender + echo "received message from ", msg.sender + asyncCheck handleNotification(ctx, msg) + except ValueError as e: + echo e.msg + discard + + else: + # We are the initiator + let serverResponse = await serverConn.sendRequest("get-peerinfo", peerId) + let peerInfo = parseMessage[OkGetPeerInfo](serverResponse) + let myProbedPorts = serverConn.probedSrcPorts.join(",") + let peerProbedPorts = peerInfo.probedPorts.join(",") + let req = &"{ctx.getPeerId()}|{peerId}|{serverConn.probedIp}|{srcPort}|{myProbedPorts}|{peerInfo.ip}|{peerInfo.localPort}|{peerProbedPorts}" + discard await serverConn.sendRequest("notify-peer", req) + let peerPort = await ctx.puncher.initiate(peerInfo.ip, peerInfo.localPort, + peerInfo.probedPorts) + initiateQuicConnection(ctx, peerId, peerInfo.ip, peerPort) proc main() = var ctx: QuicP2PContext let sock = newAsyncSocket(sockType = SOCK_DGRAM, protocol = IPPROTO_UDP, buffered = false) + randomize() + let srcPort = rand(Port(1024) .. Port.high) + sock.bindAddr(srcPort) + let puncher = initPuncher(sock) + case paramCount(): - of 1: - let portNumber = paramStr(1).parseUInt() - if portNumber > uint16.high: - usage() - quit(1) - sock.bindAddr(Port(portNumber)) - ctx = initContext(sock, serverCertChainPath, serverKeyPath, + of 0: + ctx = initContext(sock, puncher, serverCertChainPath, serverKeyPath, onServerStreamOpen) ctx.tlsCtx.require_client_authentication = 1 - asyncCheck receive(ctx, true) + asyncCheck runApp(ctx, srcPort, "") - of 2: - let hostname = paramStr(1) - let portNumber = paramStr(2).parseUInt() - if portNumber > uint16.high: - usage() - quit(1) - ctx = initContext(sock, clientCertChainPath, clientKeyPath, + of 1: + let peerId = paramStr(1) + ctx = initContext(sock, puncher, clientCertChainPath, clientKeyPath, onClientStreamOpen) - var conn: ptr quicly_conn_t - let hostent = getHostByName(hostname) - if hostent.addrList.len == 0: - echo "cannot resolve hostname ", hostname - quit(2) - var destAddr: Sockaddr_storage - var sockLen: SockLen - toSockAddr(parseIpAddress(hostent.addrList[0]), Port(portNumber), destAddr, - sockLen) - let addressToken = ptls_iovec_init(nil, 0) - let connectResult = quicly_connect(addr conn, addr ctx.quiclyCtx, - hostname.cstring, addr destAddr, nil, - addr ctx.nextCid, addressToken, nil, nil) - if connectResult != 0: - echo "quicly_connect failed: ", connectResult - quit(3) - var stream: ptr quicly_stream_t - discard quicly_open_stream(conn, addr stream, 0) - ctx.addConnection(conn) - asyncCheck receive(ctx, false) + asyncCheck runApp(ctx, srcPort, peerId) else: usage() diff --git a/server_connection.nim b/server_connection.nim new file mode 100644 index 0000000..c28e72d --- /dev/null +++ b/server_connection.nim @@ -0,0 +1,104 @@ +import asyncdispatch, asyncnet, message, net, tables, random, strformat + +type + Endpoint* = tuple[hostname: string, port: Port] + + ServerConnection* = ref object + sock: AsyncSocket + outMessages: TableRef[string, Future[string]] + peerNotifications*: FutureStream[string] + probedIp*: IpAddress + srcPort*: Port + probedSrcPorts*: seq[Port] + + ServerError* = object of ValueError + + OkGetPeerinfo* = object + ip*: IpAddress + localPort*: Port + probedPorts*: seq[Port] + + OkGetEndpoint* = object + ip*: IpAddress + port*: Port + + NotifyPeer* = object + sender*: string + recipient*: string + technique*: string + srcIp*: IpAddress + srcPort*: Port + probedSrcPorts*: seq[Port] + dstIp*: IpAddress + dstPort*: Port + probedDstPorts*: seq[Port] + extraArgs*: string + +proc getEndpoint(srcPort: Port, serverHostname: string, serverPort: Port): + Future[OkGetEndpoint] {.async.} = + let sock = newAsyncSocket() + var failCount = 0 + while true: + try: + sock.bindAddr(srcPort) + break + except OSError as e: + if failCount == 3: + raise e + failCount.inc + await sleepAsync(100) + await sock.connect(serverHostname, serverPort) + let id = rand(uint32) + await sock.send(&"get-endpoint|{id}\n") + let line = await sock.recvLine(maxLength = 400) + let args = line.parseArgs(3) + assert(args[0] == "ok") + assert(args[1] == $id) + result = parseMessage[OkGetEndpoint](args[2]) + let emptyLine = await sock.recvLine(maxLength = 400) + assert(emptyLine.len == 0) + sock.close() + +proc initServerConnection*(serverHostname: string, serverPort: Port, + srcPort: Port, probingServers: seq[Endpoint]): + Future[ServerConnection] {.async.} = + result.srcPort = srcPort + for s in probingServers: + let endpoint = await getEndpoint(srcPort, s.hostname, s.port) + # FIXME: what if we get get different IPs from different servers + result.probedIp = endpoint.ip + result.probedSrcPorts.add(endpoint.port) + result.sock = await asyncnet.dial(serverHostname, + serverPort) + result.outMessages = newTable[string, Future[string]]() + result.peerNotifications = newFutureStream[string]("initServerConnection") + +proc handleServerMessages*(conn: ServerConnection) {.async.} = + while true: + let line = await conn.sock.recvLine(maxLength = 400) + let args = line.parseArgs(3, 1) + case args[0]: + of "ok": + let future = conn.outMessages[args[1]] + conn.outMessages.del(args[1]) + future.complete(args[2]) + of "error": + let future = conn.outMessages[args[1]] + conn.outMessages.del(args[1]) + future.fail(newException(ServerError, args[2])) + of "notify-peer": + asyncCheck conn.peerNotifications.write(line.substr(args[0].len + 1)) + else: + raise newException(ValueError, "invalid server message") + +proc sendRequest*(connection: ServerConnection, command: string, + content: string): Future[string] = + result = newFuture[string]("sendRequest") + let id = $rand(uint32) + var request: string + if content.len != 0: + request = &"{command}|{id}|{content}\n" + else: + request = &"{command}|{id}\n" + asyncCheck connection.sock.send(request) + connection.outMessages[id] = result