add UDP hole punching (untested)
This commit is contained in:
parent
6aa2f46b08
commit
6edf6b7e23
|
@ -0,0 +1,55 @@
|
||||||
|
import strutils
|
||||||
|
from net import IpAddress, parseIpAddress, Port, `$`
|
||||||
|
|
||||||
|
proc parseField(input: string, output: var string) =
|
||||||
|
output = input
|
||||||
|
|
||||||
|
proc parseField[T: SomeUnsignedInt](input: string, output: var T) =
|
||||||
|
let parsed = parseUInt(input)
|
||||||
|
if parsed > T.high:
|
||||||
|
raise newException(ValueError, "Unsigned integer out of range")
|
||||||
|
output = parsed.T
|
||||||
|
|
||||||
|
proc parseField(input: string, output: var IpAddress) =
|
||||||
|
output = parseIpAddress(input)
|
||||||
|
|
||||||
|
proc parseField(input: string, output: var Port) =
|
||||||
|
var portNumber: uint16
|
||||||
|
parseField(input, portNumber)
|
||||||
|
output = Port(portNumber)
|
||||||
|
|
||||||
|
proc parseField[S, T](input: string, output: var array[S, T]) =
|
||||||
|
let parts = input.split(",", S.high)
|
||||||
|
if parts.len != S.high + 1:
|
||||||
|
raise newException(ValueError, "Array has wrong length")
|
||||||
|
for i in 0 .. S.high:
|
||||||
|
parseField(parts[i], output[i])
|
||||||
|
|
||||||
|
proc parseField[T](input: string, output: var seq[T]) =
|
||||||
|
let parts = input.split(",")
|
||||||
|
if parts.len < 1:
|
||||||
|
raise newException(ValueError, "Sequence is empty")
|
||||||
|
output = newSeq[T](parts.len)
|
||||||
|
for i in 0 .. parts.len - 1:
|
||||||
|
parseField(parts[i], output[i])
|
||||||
|
|
||||||
|
proc parseField[T: tuple | object](input: string, output: var T) =
|
||||||
|
var fieldCount = 0
|
||||||
|
for _ in output.fields:
|
||||||
|
fieldCount = fieldCount + 1
|
||||||
|
let args = input.parseArgs(fieldCount)
|
||||||
|
var i = 0
|
||||||
|
for value in output.fields:
|
||||||
|
parseField(args[i], value)
|
||||||
|
i.inc
|
||||||
|
|
||||||
|
proc parseArgs*(input: string, count: int, optionalCount = 0): seq[string] =
|
||||||
|
assert(optionalCount <= count)
|
||||||
|
result = input.split("|", count - 1)
|
||||||
|
if result.len < count:
|
||||||
|
if result.len < count - optionalCount:
|
||||||
|
raise newException(ValueError, "invalid message")
|
||||||
|
result.add(repeat("", count - result.len))
|
||||||
|
|
||||||
|
proc parseMessage*[T: tuple | object](input: string): T =
|
||||||
|
parseField(input, result)
|
|
@ -0,0 +1,165 @@
|
||||||
|
import algorithm
|
||||||
|
import net
|
||||||
|
import sequtils
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
const RandomPortCount = 1000
|
||||||
|
|
||||||
|
proc min(a, b: uint16): uint16 =
|
||||||
|
min(a.int32, b.int32).uint16
|
||||||
|
|
||||||
|
proc toUint16(p: Port): uint16 = uint16(p)
|
||||||
|
|
||||||
|
proc toPort(u: uint16): Port = Port(u)
|
||||||
|
|
||||||
|
proc addOffset(port: uint16, offset: uint16, minValue = 1024'u16,
|
||||||
|
maxValue = uint16.high): uint16 =
|
||||||
|
assert(port >= minValue)
|
||||||
|
assert(port <= maxValue)
|
||||||
|
let distanceToMaxValue = maxValue - port
|
||||||
|
if distanceToMaxValue < offset:
|
||||||
|
return minValue + offset - distanceToMaxValue - 1
|
||||||
|
return port + offset
|
||||||
|
|
||||||
|
proc subtractOffset(port: uint16, offset: uint16, minValue = 1024'u16,
|
||||||
|
maxValue = uint16.high): uint16 =
|
||||||
|
assert(port >= minValue)
|
||||||
|
assert(port <= maxValue)
|
||||||
|
let distanceToMinValue = port - minValue
|
||||||
|
if distanceToMinValue < offset:
|
||||||
|
return maxValue - offset + distanceToMinValue + 1
|
||||||
|
return port - offset
|
||||||
|
|
||||||
|
proc predictPortRange*(localPort: Port, probedPorts: seq[Port]): seq[Port] =
|
||||||
|
if probedPorts.len == 0:
|
||||||
|
# No probed ports, so our only guess can be that the NAT is a cone-type NAT
|
||||||
|
# and the port mapping preserves the local Port.
|
||||||
|
return @[localPort]
|
||||||
|
let localPortUint = localPort.uint16
|
||||||
|
let probedPortsUint = probedPorts.map(toUint16)
|
||||||
|
if probedPorts.len == 1:
|
||||||
|
# Only one server was used for probing, so we cannot know if the NAT is
|
||||||
|
# symmetric or not. We are trying the probed port (assuming cone-type NAT)
|
||||||
|
# and the next port in a progressive sequence if applicable (assuming
|
||||||
|
# symmetric NAT with progressive port mapping).
|
||||||
|
result.add(probedPorts[0])
|
||||||
|
if probedPortsUint[0] > localPortUint:
|
||||||
|
let offset = probedPortsUint[0] - localPortUint
|
||||||
|
result.add(Port(probedPortsUint[0].addOffset(offset)))
|
||||||
|
elif probedPortsUint[0] < localPortUint:
|
||||||
|
let offset = localPortUint - probedPortsUint[0]
|
||||||
|
result.add(Port(probedPortsUint[0].subtractOffset(offset)))
|
||||||
|
return
|
||||||
|
let deduplicatedPorts = probedPortsUint.deduplicate()
|
||||||
|
if deduplicatedPorts.len() == 1:
|
||||||
|
# It looks like the NAT is a cone-type NAT.
|
||||||
|
return deduplicatedPorts.map(toPort)
|
||||||
|
let probedPortsSorted = probedPortsUint.sorted()
|
||||||
|
let minPort = probedPortsSorted[probedPortsSorted.minIndex()]
|
||||||
|
let maxPort = probedPortsSorted[probedPortsSorted.maxIndex()]
|
||||||
|
var minDistance = uint16.high()
|
||||||
|
var maxDistance = uint16.low()
|
||||||
|
for i in 1 .. probedPortsSorted.len() - 1:
|
||||||
|
# FIXME: use rotated distance
|
||||||
|
let distance = probedPortsSorted[i] - probedPortsSorted[i - 1]
|
||||||
|
minDistance = min(minDistance, distance)
|
||||||
|
maxDistance = max(maxDistance, distance)
|
||||||
|
if maxDistance < 10:
|
||||||
|
if probedPortsUint.isSorted(Ascending):
|
||||||
|
# assume symmetric NAT with positive-progressive port mapping
|
||||||
|
if minDistance == maxDistance:
|
||||||
|
return @[Port(maxPort.addOffset(maxDistance))]
|
||||||
|
else:
|
||||||
|
for i in countup(0'u16, maxDistance):
|
||||||
|
result.add(Port(minPort.addOffset(i)))
|
||||||
|
return
|
||||||
|
if probedPortsUint.isSorted(Descending):
|
||||||
|
# assume symmetric NAT with negative-progressive port mapping
|
||||||
|
if minDistance == maxDistance:
|
||||||
|
return @[Port(minPort.subtractOffset(maxDistance))]
|
||||||
|
else:
|
||||||
|
for i in countup(0'u16, maxDistance):
|
||||||
|
result.add(Port(maxPort.subtractOffset(i)))
|
||||||
|
return
|
||||||
|
# assume symmetric NAT with random port mapping
|
||||||
|
let portRange = maxPort - minPort
|
||||||
|
let first = if portRange > RandomPortCount:
|
||||||
|
minPort
|
||||||
|
else:
|
||||||
|
let notCovered = RandomPortCount - portRange
|
||||||
|
max(minPort - notCovered shr 1, 1024)
|
||||||
|
let last = first + RandomPortCount
|
||||||
|
for i in first .. last:
|
||||||
|
result.add(Port(i))
|
||||||
|
|
||||||
|
suite "port prediction tests":
|
||||||
|
test "single port":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[])
|
||||||
|
check(predicted == @[Port(1234)])
|
||||||
|
|
||||||
|
test "single probe equal":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(1234)])
|
||||||
|
check(predicted == @[Port(1234)])
|
||||||
|
|
||||||
|
test "single probe positive-progressive":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(1236)])
|
||||||
|
check(predicted == @[Port(1236), Port(1238)])
|
||||||
|
|
||||||
|
test "single probe negative-progressive":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(1232)])
|
||||||
|
check(predicted == @[Port(1232), Port(1230)])
|
||||||
|
|
||||||
|
test "all equal":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(1234), Port(1234)])
|
||||||
|
check(predicted == @[Port(1234)])
|
||||||
|
|
||||||
|
test "positive-progressive, offset 1":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(2034), Port(2035)])
|
||||||
|
check(predicted == @[Port(2036)])
|
||||||
|
|
||||||
|
test "positive-progressive, offset 9":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(2034), Port(2043)])
|
||||||
|
check(predicted == @[Port(2052)])
|
||||||
|
|
||||||
|
test "negative-progressive, offset 1":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(1100), Port(1099)])
|
||||||
|
check(predicted == @[Port(1098)])
|
||||||
|
|
||||||
|
test "negative-progressive, offset 9":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(1100), Port(1091)])
|
||||||
|
check(predicted == @[Port(1082)])
|
||||||
|
|
||||||
|
test "positive-progressive, 3 probed ports, low offset":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(2000), Port(2000), Port(2002)])
|
||||||
|
check(predicted == @[Port(2000), Port(2001), Port(2002)])
|
||||||
|
|
||||||
|
test "negative-progressive, 3 probed ports, low offset":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(2002), Port(2000), Port(2000)])
|
||||||
|
check(predicted == @[Port(2002), Port(2001), Port(2000)])
|
||||||
|
|
||||||
|
test "high port, positive-progressive, offset 1":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(65534), Port(65535)])
|
||||||
|
check(predicted == @[Port(1024)])
|
||||||
|
|
||||||
|
test "high port, positive-progressive, offset 9":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(65520), Port(65529)])
|
||||||
|
check(predicted == @[Port(1026)])
|
||||||
|
|
||||||
|
test "low port, negative-progressive, offset 1":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(1025), Port(1024)])
|
||||||
|
check(predicted == @[Port(65535)])
|
||||||
|
|
||||||
|
test "low port, negative-progressive, offset 9":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(1039), Port(1030)])
|
||||||
|
check(predicted == @[Port(65533)])
|
||||||
|
|
||||||
|
test "random mapping, distance > RandomPortCount":
|
||||||
|
let predicted = predictPortRange(Port(1234), @[Port(3546), Port(7624)])
|
||||||
|
check(predicted == toSeq(countup(3546'u16, 3546'u16 + RandomPortCount)).map(toPort))
|
||||||
|
|
||||||
|
test "random mapping, distance < RandomPortCount":
|
||||||
|
let centerPort = 30000'u16
|
||||||
|
let minPort = centerPort - RandomPortCount.uint16 shr 1 + 1
|
||||||
|
let maxPort = centerPort + RandomPortCount.uint16 shr 1 - 1
|
||||||
|
let predicted = predictPortRange(Port(centerPort), @[Port(minPort), Port(maxPort)])
|
||||||
|
check(predicted == toSeq(countup(minPort - 1, maxPort + 1)).map(toPort))
|
|
@ -0,0 +1,84 @@
|
||||||
|
import asyncdispatch, asyncnet, net, port_prediction
|
||||||
|
|
||||||
|
from nativesockets import SockAddr, SockAddr_storage, SockLen
|
||||||
|
from sequtils import any
|
||||||
|
|
||||||
|
type
|
||||||
|
Attempt = object
|
||||||
|
## A hole punching attempt.
|
||||||
|
srcPort: Port
|
||||||
|
dstIp: IpAddress
|
||||||
|
dstPorts: seq[Port]
|
||||||
|
future: Future[Port]
|
||||||
|
|
||||||
|
Puncher* = ref object
|
||||||
|
sock: AsyncSocket
|
||||||
|
attempts: seq[Attempt]
|
||||||
|
|
||||||
|
PunchHoleError* = object of ValueError
|
||||||
|
|
||||||
|
const Timeout = 3000
|
||||||
|
|
||||||
|
proc `==`(a, b: Attempt): bool =
|
||||||
|
## ``==`` for hole punching attempts.
|
||||||
|
##
|
||||||
|
## Two hole punching attempts are considered equal if their ``srcPort`` and
|
||||||
|
## ``dstIp`` are equal and their ``dstPorts`` overlap.
|
||||||
|
a.srcPort == b.srcPort and a.dstIp == b.dstIp and
|
||||||
|
a.dstPorts.any(proc (p: Port): bool = p in b.dstPorts)
|
||||||
|
|
||||||
|
proc initPuncher*(sock: AsyncSocket): Puncher =
|
||||||
|
Puncher(sock: sock)
|
||||||
|
|
||||||
|
proc punch(puncher: Puncher, peerIp: IpAddress, peerPort: Port,
|
||||||
|
peerProbedPorts: seq[Port], msg: string): Future[Port] {.async.} =
|
||||||
|
let punchFuture = newFuture[Port]("punch")
|
||||||
|
let predictedDstPorts = predictPortRange(peerPort, peerProbedPorts)
|
||||||
|
let (_, myPort) = puncher.sock.getLocalAddr()
|
||||||
|
let attempt = Attempt(srcPort: myPort, dstIp: peerIp,
|
||||||
|
dstPorts: predictedDstPorts, future: punchFuture)
|
||||||
|
if puncher.attempts.contains(attempt):
|
||||||
|
raise newException(PunchHoleError,
|
||||||
|
"hole punching for given parameters already active")
|
||||||
|
puncher.attempts.add(attempt)
|
||||||
|
var peerAddr: Sockaddr_storage
|
||||||
|
var peerSockLen: SockLen
|
||||||
|
try:
|
||||||
|
for dstPort in attempt.dstPorts:
|
||||||
|
toSockAddr(attempt.dstIp, dstPort, peerAddr, peerSockLen)
|
||||||
|
# TODO: replace asyncdispatch.sendTo with asyncnet.sendTo (Nim 1.4 required)
|
||||||
|
await sendTo(puncher.sock.getFd().AsyncFD, msg.cstring, msg.len,
|
||||||
|
cast[ptr SockAddr](addr peerAddr), peerSockLen)
|
||||||
|
await punchFuture or sleepAsync(Timeout)
|
||||||
|
if punchFuture.finished():
|
||||||
|
result = punchFuture.read()
|
||||||
|
else:
|
||||||
|
raise newException(PunchHoleError, "timeout")
|
||||||
|
except OSError as e:
|
||||||
|
raise newException(PunchHoleError, e.msg)
|
||||||
|
|
||||||
|
proc initiate*(puncher: Puncher, peerIp: IpAddress, peerPort: Port,
|
||||||
|
peerProbedPorts: seq[Port]): Future[Port] =
|
||||||
|
punch(puncher, peerIp, peerPort, peerProbedPorts, "SYN")
|
||||||
|
|
||||||
|
proc respond*(puncher: Puncher, peerIp: IpAddress, peerPort: Port,
|
||||||
|
peerProbedPorts: seq[Port]): Future[Port] =
|
||||||
|
punch(puncher, peerIp, peerPort, peerProbedPorts, "ACK")
|
||||||
|
|
||||||
|
proc handleMsg*(puncher: Puncher, msg: string, peerIp: IpAddress,
|
||||||
|
peerPort: Port) =
|
||||||
|
## Handles an incoming UDP message which may complete the Futures returned by
|
||||||
|
## ``initiate`` and ``respond``.
|
||||||
|
let (_, myPort) = puncher.sock.getLocalAddr()
|
||||||
|
let query = Attempt(srcPort: myPort, dstIp: peerIp, dstPorts: @[peerPort])
|
||||||
|
let i = puncher.attempts.find(query)
|
||||||
|
if i != -1:
|
||||||
|
puncher.attempts[i].future.complete(peerPort)
|
||||||
|
puncher.attempts.del(i)
|
||||||
|
|
||||||
|
proc handleMsg*(puncher: Puncher, msg: string,
|
||||||
|
peerAddr: SockAddr | Sockaddr_storage, peerSockLen: SockLen) =
|
||||||
|
var peerIp: IpAddress
|
||||||
|
var peerPort: Port
|
||||||
|
fromSockAddr(peerAddr, peerSockLen, peerIp, peerPort)
|
||||||
|
handleMsg(puncher, msg, peerIp, peerPort)
|
157
quicp2p.nim
157
quicp2p.nim
|
@ -4,11 +4,13 @@ import asyncdispatch
|
||||||
import asyncnet
|
import asyncnet
|
||||||
import base32
|
import base32
|
||||||
import certificate
|
import certificate
|
||||||
|
import message
|
||||||
import net
|
import net
|
||||||
import os
|
import os
|
||||||
import openssl_additional
|
import openssl_additional
|
||||||
import picotls/picotls
|
import picotls/picotls
|
||||||
import picotls/openssl as ptls_openssl
|
import picotls/openssl as ptls_openssl
|
||||||
|
import puncher
|
||||||
import quicly/quicly
|
import quicly/quicly
|
||||||
import quicly/cid
|
import quicly/cid
|
||||||
import quicly/constants
|
import quicly/constants
|
||||||
|
@ -16,6 +18,8 @@ import quicly/defaults
|
||||||
import quicly/recvstate
|
import quicly/recvstate
|
||||||
import quicly/sendstate
|
import quicly/sendstate
|
||||||
import quicly/streambuf
|
import quicly/streambuf
|
||||||
|
import random
|
||||||
|
import server_connection
|
||||||
import strformat
|
import strformat
|
||||||
import strutils
|
import strutils
|
||||||
|
|
||||||
|
@ -35,18 +39,15 @@ from openssl import
|
||||||
PSTACK,
|
PSTACK,
|
||||||
d2i_X509
|
d2i_X509
|
||||||
|
|
||||||
const serverCertChainPath = "./certs/server-certchain.pem"
|
|
||||||
const serverKeyPath = "./certs/server-cert.key"
|
|
||||||
const clientCertChainPath = "./certs/client-certchain.pem"
|
|
||||||
const clientKeyPath = "./certs/client-cert.key"
|
|
||||||
|
|
||||||
type
|
type
|
||||||
Connection = ref object
|
Connection = ref object
|
||||||
conn: ptr quicly_conn_t
|
conn: ptr quicly_conn_t
|
||||||
certs: seq[Certificate]
|
certs: seq[Certificate]
|
||||||
|
peerId: string
|
||||||
|
|
||||||
QuicP2PContext = ref object
|
QuicP2PContext = ref object
|
||||||
sock: AsyncSocket
|
sock: AsyncSocket
|
||||||
|
puncher: Puncher
|
||||||
streamOpen: quicly_stream_open_t
|
streamOpen: quicly_stream_open_t
|
||||||
nextCid: quicly_cid_plaintext_t
|
nextCid: quicly_cid_plaintext_t
|
||||||
signCertCb: ptls_openssl_sign_certificate_t
|
signCertCb: ptls_openssl_sign_certificate_t
|
||||||
|
@ -55,6 +56,16 @@ type
|
||||||
quiclyCtx: quicly_context_t
|
quiclyCtx: quicly_context_t
|
||||||
connections: seq[Connection]
|
connections: seq[Connection]
|
||||||
|
|
||||||
|
const serverCertChainPath = "./certs/server-certchain.pem"
|
||||||
|
const serverKeyPath = "./certs/server-cert.key"
|
||||||
|
const clientCertChainPath = "./certs/client-certchain.pem"
|
||||||
|
const clientKeyPath = "./certs/client-cert.key"
|
||||||
|
|
||||||
|
const rendezvousServers: seq[tuple[hostname: string, port: Port]] = @[
|
||||||
|
("strangeplace.net", Port(5320)),
|
||||||
|
("ulrich.earth", Port(5320))
|
||||||
|
]
|
||||||
|
|
||||||
proc getRelativeTimeout(ctx: QuicP2PContext): int32 =
|
proc getRelativeTimeout(ctx: QuicP2PContext): int32 =
|
||||||
## Obtain the absolute int64 timeout from quicly and convert it to the
|
## Obtain the absolute int64 timeout from quicly and convert it to the
|
||||||
## relative int32 timeout expected by poll.
|
## relative int32 timeout expected by poll.
|
||||||
|
@ -197,7 +208,8 @@ proc verifyCerts(self: ptr ptls_verify_certificate_t, tls: ptr ptls_t,
|
||||||
X509_STORE_free(store)
|
X509_STORE_free(store)
|
||||||
X509_free(caCert)
|
X509_free(caCert)
|
||||||
|
|
||||||
proc initContext(sock: AsyncSocket, certChainPath: string, keyPath: string,
|
proc initContext(sock: AsyncSocket, puncher: Puncher, certChainPath: string,
|
||||||
|
keyPath: string,
|
||||||
streamOpenCb: typeof(quicly_stream_open_t.cb)):
|
streamOpenCb: typeof(quicly_stream_open_t.cb)):
|
||||||
QuicP2PContext =
|
QuicP2PContext =
|
||||||
var tlsCtx = ptls_context_t(randomBytes: ptls_openssl_random_bytes,
|
var tlsCtx = ptls_context_t(randomBytes: ptls_openssl_random_bytes,
|
||||||
|
@ -205,7 +217,7 @@ proc initContext(sock: AsyncSocket, certChainPath: string, keyPath: string,
|
||||||
keyExchanges: ptls_openssl_key_exchanges,
|
keyExchanges: ptls_openssl_key_exchanges,
|
||||||
cipherSuites: ptls_openssl_cipher_suites)
|
cipherSuites: ptls_openssl_cipher_suites)
|
||||||
quicly_amend_ptls_context(addr tlsCtx)
|
quicly_amend_ptls_context(addr tlsCtx)
|
||||||
result = QuicP2PContext(sock: sock,
|
result = QuicP2PContext(sock: sock, puncher: puncher,
|
||||||
streamOpen: quicly_stream_open_t(cb: streamOpenCb),
|
streamOpen: quicly_stream_open_t(cb: streamOpenCb),
|
||||||
verifyCertsCb: ptls_verify_certificate_t(cb: verifyCerts),
|
verifyCertsCb: ptls_verify_certificate_t(cb: verifyCerts),
|
||||||
tlsCtx: tlsCtx, quiclyCtx: quicly_spec_context)
|
tlsCtx: tlsCtx, quiclyCtx: quicly_spec_context)
|
||||||
|
@ -223,10 +235,11 @@ proc initContext(sock: AsyncSocket, certChainPath: string, keyPath: string,
|
||||||
EVP_PKEY_free(privateKey)
|
EVP_PKEY_free(privateKey)
|
||||||
result.tlsCtx.sign_certificate = addr result.signCertCb.super
|
result.tlsCtx.sign_certificate = addr result.signCertCb.super
|
||||||
|
|
||||||
proc addConnection(ctx: QuicP2PContext, connPtr: ptr quicly_conn_t) =
|
proc addConnection(ctx: QuicP2PContext, connPtr: ptr quicly_conn_t,
|
||||||
|
peerId: string) =
|
||||||
assert(not connPtr.isNil)
|
assert(not connPtr.isNil)
|
||||||
let data = quicly_get_data(connPtr)
|
let data = quicly_get_data(connPtr)
|
||||||
var conn = Connection(conn: connPtr)
|
var conn = Connection(conn: connPtr, peerId: peerId)
|
||||||
data[] = addr conn[]
|
data[] = addr conn[]
|
||||||
ctx.connections.add(conn)
|
ctx.connections.add(conn)
|
||||||
|
|
||||||
|
@ -262,8 +275,25 @@ proc sendPackets(ctx: QuicP2PContext) =
|
||||||
else:
|
else:
|
||||||
raise newException(ValueError, &"quicly_send returned {sendResult}")
|
raise newException(ValueError, &"quicly_send returned {sendResult}")
|
||||||
|
|
||||||
proc handleMsg(ctx: QuicP2PContext, msg: string, peerAddr: ptr SockAddr,
|
proc initiateQuicConnection(ctx: QuicP2PContext, peerId: string,
|
||||||
isServer: bool) =
|
peerIp: IpAddress, peerPort: Port) =
|
||||||
|
var conn: ptr quicly_conn_t
|
||||||
|
var peerAddr: SockAddr_storage
|
||||||
|
var peerSockLen: SockLen
|
||||||
|
toSockAddr(peerIp, peerPort, peerAddr, peerSockLen)
|
||||||
|
let addressToken = ptls_iovec_init(nil, 0)
|
||||||
|
let connectResult = quicly_connect(addr conn, addr ctx.quiclyCtx,
|
||||||
|
peerId.cstring, addr peerAddr, nil,
|
||||||
|
addr ctx.nextCid, addressToken, nil, nil)
|
||||||
|
if connectResult != 0:
|
||||||
|
echo "quicly_connect failed: ", connectResult
|
||||||
|
return
|
||||||
|
var stream: ptr quicly_stream_t
|
||||||
|
discard quicly_open_stream(conn, addr stream, 0)
|
||||||
|
ctx.addConnection(conn, peerId)
|
||||||
|
|
||||||
|
proc handleMsg(ctx: QuicP2PContext, msg: string, peerId: string,
|
||||||
|
peerAddr: ptr Sockaddr_storage, peerSockLen: SockLen) =
|
||||||
var offset: csize_t = 0
|
var offset: csize_t = 0
|
||||||
while offset < msg.len().csize_t:
|
while offset < msg.len().csize_t:
|
||||||
var decoded: quicly_decoded_packet_t
|
var decoded: quicly_decoded_packet_t
|
||||||
|
@ -271,6 +301,10 @@ proc handleMsg(ctx: QuicP2PContext, msg: string, peerAddr: ptr SockAddr,
|
||||||
cast[ptr uint8](msg.cstring),
|
cast[ptr uint8](msg.cstring),
|
||||||
msg.len().csize_t, addr offset)
|
msg.len().csize_t, addr offset)
|
||||||
if decode_result == csize_t.high:
|
if decode_result == csize_t.high:
|
||||||
|
# The puncher needs to be informed about this message because quicly not
|
||||||
|
# being able to decode it may indicate it's the peer's response to our
|
||||||
|
# initiate call.
|
||||||
|
ctx.puncher.handleMsg(msg, peerAddr[], peerSockLen)
|
||||||
return
|
return
|
||||||
var conn: ptr quicly_conn_t = nil
|
var conn: ptr quicly_conn_t = nil
|
||||||
for c in ctx.connections:
|
for c in ctx.connections:
|
||||||
|
@ -279,12 +313,16 @@ proc handleMsg(ctx: QuicP2PContext, msg: string, peerAddr: ptr SockAddr,
|
||||||
break
|
break
|
||||||
if conn != nil:
|
if conn != nil:
|
||||||
discard quicly_receive(conn, nil, peerAddr, addr decoded)
|
discard quicly_receive(conn, nil, peerAddr, addr decoded)
|
||||||
elif isServer:
|
elif peerId.len != 0:
|
||||||
|
# The puncher needs to be informed about this message because it may
|
||||||
|
# be the peer's response to our respond call. Quicly needs to be informed
|
||||||
|
# because we except the first QUIC handshake packet in it.
|
||||||
|
ctx.puncher.handleMsg(msg, peerAddr[], peerSockLen)
|
||||||
discard quicly_accept(addr conn, addr ctx.quiclyCtx, nil, peerAddr,
|
discard quicly_accept(addr conn, addr ctx.quiclyCtx, nil, peerAddr,
|
||||||
addr decoded, nil, addr ctx.nextCid, nil)
|
addr decoded, nil, addr ctx.nextCid, nil)
|
||||||
ctx.addConnection(conn)
|
ctx.addConnection(conn, peerId)
|
||||||
|
|
||||||
proc receive(ctx: QuicP2PContext, isServer: bool) {.async.} =
|
proc receive(ctx: QuicP2PContext, peerId: string) {.async.} =
|
||||||
while true:
|
while true:
|
||||||
# TODO: replace asyncdispatch.recvFromInto with asyncnet.recvFrom (Nim 1.4 required)
|
# TODO: replace asyncdispatch.recvFromInto with asyncnet.recvFrom (Nim 1.4 required)
|
||||||
var msg = newString(BufferSize)
|
var msg = newString(BufferSize)
|
||||||
|
@ -295,52 +333,71 @@ proc receive(ctx: QuicP2PContext, isServer: bool) {.async.} =
|
||||||
addr peerAddrLen)
|
addr peerAddrLen)
|
||||||
msg.setLen(msgLen)
|
msg.setLen(msgLen)
|
||||||
if msg.len > 0:
|
if msg.len > 0:
|
||||||
handleMsg(ctx, msg, cast[ptr SockAddr](addr peerAddr), isServer)
|
handleMsg(ctx, msg, peerId, addr peerAddr, peerAddrLen)
|
||||||
|
|
||||||
|
proc handleNotification(ctx: QuicP2PContext, notification: NotifyPeer)
|
||||||
|
{.async.} =
|
||||||
|
let _ = await ctx.puncher.respond(notification.dstIp, notification.dstPort,
|
||||||
|
notification.probedDstPorts)
|
||||||
|
|
||||||
|
proc runApp(ctx: QuicP2PContext, srcPort: Port, peerId: string) {.async.} =
|
||||||
|
let serverConn = await initServerConnection(rendezvousServers[0].hostname,
|
||||||
|
rendezvousServers[0].port,
|
||||||
|
srcPort, rendezvousServers)
|
||||||
|
asyncCheck handleServerMessages(serverConn)
|
||||||
|
asyncCheck receive(ctx, peerId)
|
||||||
|
|
||||||
|
if peerId.len == 0:
|
||||||
|
# We are the responder
|
||||||
|
let probedPorts = serverConn.probedSrcPorts.join(",")
|
||||||
|
let req = &"{ctx.getPeerId()}|{serverConn.probedIp}|{srcPort}|{probedPorts}"
|
||||||
|
discard await serverConn.sendRequest("register", req)
|
||||||
|
while true:
|
||||||
|
let (hasData, data) = await serverConn.peerNotifications.read()
|
||||||
|
if not hasData:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
let msg = parseMessage[NotifyPeer](data)
|
||||||
|
# FIXME: check if we want to receive messages from the sender
|
||||||
|
echo "received message from ", msg.sender
|
||||||
|
asyncCheck handleNotification(ctx, msg)
|
||||||
|
except ValueError as e:
|
||||||
|
echo e.msg
|
||||||
|
discard
|
||||||
|
|
||||||
|
else:
|
||||||
|
# We are the initiator
|
||||||
|
let serverResponse = await serverConn.sendRequest("get-peerinfo", peerId)
|
||||||
|
let peerInfo = parseMessage[OkGetPeerInfo](serverResponse)
|
||||||
|
let myProbedPorts = serverConn.probedSrcPorts.join(",")
|
||||||
|
let peerProbedPorts = peerInfo.probedPorts.join(",")
|
||||||
|
let req = &"{ctx.getPeerId()}|{peerId}|{serverConn.probedIp}|{srcPort}|{myProbedPorts}|{peerInfo.ip}|{peerInfo.localPort}|{peerProbedPorts}"
|
||||||
|
discard await serverConn.sendRequest("notify-peer", req)
|
||||||
|
let peerPort = await ctx.puncher.initiate(peerInfo.ip, peerInfo.localPort,
|
||||||
|
peerInfo.probedPorts)
|
||||||
|
initiateQuicConnection(ctx, peerId, peerInfo.ip, peerPort)
|
||||||
|
|
||||||
proc main() =
|
proc main() =
|
||||||
var ctx: QuicP2PContext
|
var ctx: QuicP2PContext
|
||||||
let sock = newAsyncSocket(sockType = SOCK_DGRAM, protocol = IPPROTO_UDP,
|
let sock = newAsyncSocket(sockType = SOCK_DGRAM, protocol = IPPROTO_UDP,
|
||||||
buffered = false)
|
buffered = false)
|
||||||
|
randomize()
|
||||||
|
let srcPort = rand(Port(1024) .. Port.high)
|
||||||
|
sock.bindAddr(srcPort)
|
||||||
|
let puncher = initPuncher(sock)
|
||||||
|
|
||||||
case paramCount():
|
case paramCount():
|
||||||
of 1:
|
of 0:
|
||||||
let portNumber = paramStr(1).parseUInt()
|
ctx = initContext(sock, puncher, serverCertChainPath, serverKeyPath,
|
||||||
if portNumber > uint16.high:
|
|
||||||
usage()
|
|
||||||
quit(1)
|
|
||||||
sock.bindAddr(Port(portNumber))
|
|
||||||
ctx = initContext(sock, serverCertChainPath, serverKeyPath,
|
|
||||||
onServerStreamOpen)
|
onServerStreamOpen)
|
||||||
ctx.tlsCtx.require_client_authentication = 1
|
ctx.tlsCtx.require_client_authentication = 1
|
||||||
asyncCheck receive(ctx, true)
|
asyncCheck runApp(ctx, srcPort, "")
|
||||||
|
|
||||||
of 2:
|
of 1:
|
||||||
let hostname = paramStr(1)
|
let peerId = paramStr(1)
|
||||||
let portNumber = paramStr(2).parseUInt()
|
ctx = initContext(sock, puncher, clientCertChainPath, clientKeyPath,
|
||||||
if portNumber > uint16.high:
|
|
||||||
usage()
|
|
||||||
quit(1)
|
|
||||||
ctx = initContext(sock, clientCertChainPath, clientKeyPath,
|
|
||||||
onClientStreamOpen)
|
onClientStreamOpen)
|
||||||
var conn: ptr quicly_conn_t
|
asyncCheck runApp(ctx, srcPort, peerId)
|
||||||
let hostent = getHostByName(hostname)
|
|
||||||
if hostent.addrList.len == 0:
|
|
||||||
echo "cannot resolve hostname ", hostname
|
|
||||||
quit(2)
|
|
||||||
var destAddr: Sockaddr_storage
|
|
||||||
var sockLen: SockLen
|
|
||||||
toSockAddr(parseIpAddress(hostent.addrList[0]), Port(portNumber), destAddr,
|
|
||||||
sockLen)
|
|
||||||
let addressToken = ptls_iovec_init(nil, 0)
|
|
||||||
let connectResult = quicly_connect(addr conn, addr ctx.quiclyCtx,
|
|
||||||
hostname.cstring, addr destAddr, nil,
|
|
||||||
addr ctx.nextCid, addressToken, nil, nil)
|
|
||||||
if connectResult != 0:
|
|
||||||
echo "quicly_connect failed: ", connectResult
|
|
||||||
quit(3)
|
|
||||||
var stream: ptr quicly_stream_t
|
|
||||||
discard quicly_open_stream(conn, addr stream, 0)
|
|
||||||
ctx.addConnection(conn)
|
|
||||||
asyncCheck receive(ctx, false)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
usage()
|
usage()
|
||||||
|
|
|
@ -0,0 +1,104 @@
|
||||||
|
import asyncdispatch, asyncnet, message, net, tables, random, strformat
|
||||||
|
|
||||||
|
type
|
||||||
|
Endpoint* = tuple[hostname: string, port: Port]
|
||||||
|
|
||||||
|
ServerConnection* = ref object
|
||||||
|
sock: AsyncSocket
|
||||||
|
outMessages: TableRef[string, Future[string]]
|
||||||
|
peerNotifications*: FutureStream[string]
|
||||||
|
probedIp*: IpAddress
|
||||||
|
srcPort*: Port
|
||||||
|
probedSrcPorts*: seq[Port]
|
||||||
|
|
||||||
|
ServerError* = object of ValueError
|
||||||
|
|
||||||
|
OkGetPeerinfo* = object
|
||||||
|
ip*: IpAddress
|
||||||
|
localPort*: Port
|
||||||
|
probedPorts*: seq[Port]
|
||||||
|
|
||||||
|
OkGetEndpoint* = object
|
||||||
|
ip*: IpAddress
|
||||||
|
port*: Port
|
||||||
|
|
||||||
|
NotifyPeer* = object
|
||||||
|
sender*: string
|
||||||
|
recipient*: string
|
||||||
|
technique*: string
|
||||||
|
srcIp*: IpAddress
|
||||||
|
srcPort*: Port
|
||||||
|
probedSrcPorts*: seq[Port]
|
||||||
|
dstIp*: IpAddress
|
||||||
|
dstPort*: Port
|
||||||
|
probedDstPorts*: seq[Port]
|
||||||
|
extraArgs*: string
|
||||||
|
|
||||||
|
proc getEndpoint(srcPort: Port, serverHostname: string, serverPort: Port):
|
||||||
|
Future[OkGetEndpoint] {.async.} =
|
||||||
|
let sock = newAsyncSocket()
|
||||||
|
var failCount = 0
|
||||||
|
while true:
|
||||||
|
try:
|
||||||
|
sock.bindAddr(srcPort)
|
||||||
|
break
|
||||||
|
except OSError as e:
|
||||||
|
if failCount == 3:
|
||||||
|
raise e
|
||||||
|
failCount.inc
|
||||||
|
await sleepAsync(100)
|
||||||
|
await sock.connect(serverHostname, serverPort)
|
||||||
|
let id = rand(uint32)
|
||||||
|
await sock.send(&"get-endpoint|{id}\n")
|
||||||
|
let line = await sock.recvLine(maxLength = 400)
|
||||||
|
let args = line.parseArgs(3)
|
||||||
|
assert(args[0] == "ok")
|
||||||
|
assert(args[1] == $id)
|
||||||
|
result = parseMessage[OkGetEndpoint](args[2])
|
||||||
|
let emptyLine = await sock.recvLine(maxLength = 400)
|
||||||
|
assert(emptyLine.len == 0)
|
||||||
|
sock.close()
|
||||||
|
|
||||||
|
proc initServerConnection*(serverHostname: string, serverPort: Port,
|
||||||
|
srcPort: Port, probingServers: seq[Endpoint]):
|
||||||
|
Future[ServerConnection] {.async.} =
|
||||||
|
result.srcPort = srcPort
|
||||||
|
for s in probingServers:
|
||||||
|
let endpoint = await getEndpoint(srcPort, s.hostname, s.port)
|
||||||
|
# FIXME: what if we get get different IPs from different servers
|
||||||
|
result.probedIp = endpoint.ip
|
||||||
|
result.probedSrcPorts.add(endpoint.port)
|
||||||
|
result.sock = await asyncnet.dial(serverHostname,
|
||||||
|
serverPort)
|
||||||
|
result.outMessages = newTable[string, Future[string]]()
|
||||||
|
result.peerNotifications = newFutureStream[string]("initServerConnection")
|
||||||
|
|
||||||
|
proc handleServerMessages*(conn: ServerConnection) {.async.} =
|
||||||
|
while true:
|
||||||
|
let line = await conn.sock.recvLine(maxLength = 400)
|
||||||
|
let args = line.parseArgs(3, 1)
|
||||||
|
case args[0]:
|
||||||
|
of "ok":
|
||||||
|
let future = conn.outMessages[args[1]]
|
||||||
|
conn.outMessages.del(args[1])
|
||||||
|
future.complete(args[2])
|
||||||
|
of "error":
|
||||||
|
let future = conn.outMessages[args[1]]
|
||||||
|
conn.outMessages.del(args[1])
|
||||||
|
future.fail(newException(ServerError, args[2]))
|
||||||
|
of "notify-peer":
|
||||||
|
asyncCheck conn.peerNotifications.write(line.substr(args[0].len + 1))
|
||||||
|
else:
|
||||||
|
raise newException(ValueError, "invalid server message")
|
||||||
|
|
||||||
|
proc sendRequest*(connection: ServerConnection, command: string,
|
||||||
|
content: string): Future[string] =
|
||||||
|
result = newFuture[string]("sendRequest")
|
||||||
|
let id = $rand(uint32)
|
||||||
|
var request: string
|
||||||
|
if content.len != 0:
|
||||||
|
request = &"{command}|{id}|{content}\n"
|
||||||
|
else:
|
||||||
|
request = &"{command}|{id}\n"
|
||||||
|
asyncCheck connection.sock.send(request)
|
||||||
|
connection.outMessages[id] = result
|
Loading…
Reference in New Issue