change puncher interface to allow accepting multiple connections

This commit is contained in:
Christian Ulrich 2020-10-02 17:12:29 +02:00
parent dce5115c5c
commit 84cb8611ef
No known key found for this signature in database
GPG Key ID: 8241BE099775A097
2 changed files with 177 additions and 102 deletions

View File

@ -13,7 +13,7 @@ const PunchdSocket = "/tmp/punchd.socket"
type type
Punchd = ref object Punchd = ref object
unixSocket: AsyncSocket unixSocket: AsyncSocket
punchers: seq[TcpSyniPuncher] tcpSyniPuncher: TcpSyniPuncher
Sigint = object of CatchableError Sigint = object of CatchableError
@ -52,16 +52,13 @@ proc handleRequest(punchd: Punchd, line: string,
$req.dstIp, req.dstPorts.join(","), $req.dstIp, req.dstPorts.join(","),
seqNumbers.join(",")].join("|") seqNumbers.join(",")].join("|")
await unixSock.send(&"progress|{id}|{content}\n") await unixSock.send(&"progress|{id}|{content}\n")
puncher = initPuncher(req.srcPorts[0], req.dstIp, req.dstPorts) sock = await punchd.tcpSyniPuncher.connect(req.srcPorts[0], req.dstIp,
punchd.punchers.add(puncher) req.dstPorts, handleSeqNumbers)
sock = await puncher.connect(handleSeqNumbers)
of "tcp-syni-accept": of "tcp-syni-accept":
let req = parseMessage[TcpSyniAccept](args[2]) let req = parseMessage[TcpSyniAccept](args[2])
puncher = initPuncher(req.srcPorts[0], req.dstIp, req.dstPorts, sock = await punchd.tcpSyniPuncher.accept(req.srcPorts[0], req.dstIp,
req.seqNums) req.dstPorts, req.seqNums)
punchd.punchers.add(puncher)
sock = await puncher.accept()
else: else:
raise newException(ValueError, "invalid request") raise newException(ValueError, "invalid request")
@ -91,19 +88,19 @@ proc handleUsers(punchd: Punchd) {.async.} =
proc main() = proc main() =
setControlCHook(handleSigint) setControlCHook(handleSigint)
removeFile(PunchdSocket) removeFile(PunchdSocket)
let punchd = Punchd(unixSocket: newAsyncSocket(AF_UNIX, SOCK_STREAM, IPPROTO_IP)) let unixSocket = newAsyncSocket(AF_UNIX, SOCK_STREAM, IPPROTO_IP)
punchd.unixSocket.bindUnix(PunchdSocket) unixSocket.bindUnix(PunchdSocket)
unixSocket.listen()
setFilePermissions(PunchdSocket, setFilePermissions(PunchdSocket,
{fpUserRead, fpUserWrite, fpGroupRead, fpGroupWrite, {fpUserRead, fpUserWrite, fpGroupRead, fpGroupWrite,
fpOthersRead, fpOthersWrite}) fpOthersRead, fpOthersWrite})
punchd.unixSocket.listen() let punchd = Punchd(unixSocket: unixSocket, tcpSyniPuncher: initPuncher())
asyncCheck handleUsers(punchd) asyncCheck handleUsers(punchd)
try: try:
runForever() runForever()
except Sigint: except Sigint:
for puncher in punchd.punchers: waitFor punchd.tcpSyniPuncher.cleanup()
waitFor puncher.cleanup()
punchd.unixSocket.close() punchd.unixSocket.close()
removeFile(PunchdSocket) removeFile(PunchdSocket)
quit(0) quit(0)

View File

@ -1,6 +1,7 @@
import asyncfutures, asyncdispatch, asyncnet, strformat import asyncfutures, asyncdispatch, asyncnet, strformat
from net import IpAddress, Port, `$`, `==`, getPrimaryIPAddr, toSockAddr from net import IpAddress, Port, `$`, `==`, getPrimaryIPAddr, toSockAddr, parseIpAddress
from nativesockets import SockAddr, Sockaddr_storage, SockLen, setSockOptInt from nativesockets import SockAddr, Sockaddr_storage, SockLen, setSockOptInt
from sequtils import any
import asyncutils import asyncutils
import ip_packet import ip_packet
import network_interface import network_interface
@ -12,13 +13,25 @@ var IP_TTL {.importc: "IP_TTL", header: "<netinet/in.h>".}: cint
const Timeout = 3000 const Timeout = 3000
type type
TcpSyniPuncher* = ref object ConnectAttempt* = ref object
srcIp: IpAddress
srcPort: Port
dstIp: IpAddress
dstPorts: seq[Port]
firewallRules: seq[string]
AcceptAttempt* = ref object
srcIp: IpAddress srcIp: IpAddress
srcPort: Port srcPort: Port
dstIp: IpAddress dstIp: IpAddress
dstPorts: seq[Port] dstPorts: seq[Port]
seqNums: seq[uint32] seqNums: seq[uint32]
firewallRules: seq[string] firewallRules: seq[string]
future: Future[AsyncSocket]
TcpSyniPuncher* = ref object
connectAttempts: seq[ConnectAttempt]
acceptAttempts: seq[AcceptAttempt]
PunchProgressCb* = proc (seqNums: seq[uint32]) {.async.} PunchProgressCb* = proc (seqNums: seq[uint32]) {.async.}
@ -47,6 +60,26 @@ proc iptablesDelete(chain: string, rule: string) {.async.} =
let firewall_cmd = fmt"iptables -D {chain} {rule}" let firewall_cmd = fmt"iptables -D {chain} {rule}"
discard await asyncExecCmd(firewall_cmd) discard await asyncExecCmd(firewall_cmd)
proc addFirewallRules[T](attempt: T) {.async.} =
for dstPort in attempt.dstPorts:
let rule = makeFirewallRule(attempt.srcIp, attempt.srcPort,
attempt.dstIp, dstPort)
try:
await iptablesInsert("INPUT", rule)
attempt.firewallRules.add(rule)
except OSError as e:
echo "cannot add firewall rule: ", e.msg
raise newException(PunchHoleError, e.msg)
proc deleteFirewallRules[T](attempt: T) {.async.} =
for rule in attempt.firewallRules:
# FIXME: close sock?
try:
await iptablesDelete("INPUT", rule)
except OSError:
# At least we tried
discard
proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} = proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} =
assert(ipPacket.protocol == tcp) assert(ipPacket.protocol == tcp)
try: try:
@ -59,27 +92,27 @@ proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} =
except OSError as e: except OSError as e:
raise newException(PunchHoleError, e.msg) raise newException(PunchHoleError, e.msg)
proc captureSeqNumbers(puncher: TcpSyniPuncher, rawFd: AsyncFD, proc captureSeqNumbers(attempt: ConnectAttempt, rawFd: AsyncFD,
cb: PunchProgressCb) {.async.} = cb: PunchProgressCb) {.async.} =
# FIXME: timeout? # FIXME: timeout?
var seqNums = newSeq[uint32]() var seqNums = newSeq[uint32]()
while seqNums.len < puncher.dstPorts.len: while seqNums.len < attempt.dstPorts.len:
let packet = await rawFd.recv(4000) let packet = await rawFd.recv(4000)
if packet == "": if packet == "":
break break
let parsed = parseEthernetPacket(packet) let parsed = parseEthernetPacket(packet)
if parsed.protocol == tcp and if parsed.protocol == tcp and
parsed.ipAddrSrc == puncher.srcIp and parsed.ipAddrSrc == attempt.srcIp and
parsed.tcpPortSrc.int == puncher.srcPort.int and parsed.tcpPortSrc.int == attempt.srcPort.int and
parsed.ipAddrDst == puncher.dstIp and parsed.ipAddrDst == attempt.dstIp and
parsed.tcpFlags == {SYN}: parsed.tcpFlags == {SYN}:
for port in puncher.dstPorts: for port in attempt.dstPorts:
if parsed.tcpPortDst.int == port.int: if parsed.tcpPortDst.int == port.int:
seqNums.add(parsed.tcpSeqNumber) seqNums.add(parsed.tcpSeqNumber)
break break
await cb(seqNums) await cb(seqNums)
proc captureAndResendAck(puncher: TcpSyniPuncher, captureFd: AsyncFD, proc captureAndResendAck(attempt: ConnectAttempt, captureFd: AsyncFD,
injectFd: AsyncFD) {.async.} = injectFd: AsyncFD) {.async.} =
while true: while true:
let packet = await captureFd.recv(4000) let packet = await captureFd.recv(4000)
@ -87,47 +120,56 @@ proc captureAndResendAck(puncher: TcpSyniPuncher, captureFd: AsyncFD,
break break
var parsed = parseEthernetPacket(packet) var parsed = parseEthernetPacket(packet)
if parsed.protocol == tcp and if parsed.protocol == tcp and
parsed.ipAddrSrc == puncher.srcIp and parsed.ipAddrSrc == attempt.srcIp and
parsed.tcpPortSrc.int == puncher.srcPort.int and parsed.tcpPortSrc.int == attempt.srcPort.int and
parsed.ipAddrDst == puncher.dstIp and parsed.ipAddrDst == attempt.dstIp and
parsed.tcpFlags == {ACK}: parsed.tcpFlags == {ACK}:
for port in puncher.dstPorts: for port in attempt.dstPorts:
if parsed.tcpPortDst.int == port.int: if parsed.tcpPortDst.int == port.int:
parsed.ipTTL = 64 parsed.ipTTL = 64
echo &"[{parsed.ipAddrSrc}:{parsed.tcpPortSrc.int} -> {parsed.ipAddrDst}:{parsed.tcpPortDst}, SEQ {parsed.tcpSeqNumber}] resending ACK with TTL {parsed.ipTTL}" echo &"[{parsed.ipAddrSrc}:{parsed.tcpPortSrc.int} -> {parsed.ipAddrDst}:{parsed.tcpPortDst}, SEQ {parsed.tcpSeqNumber}] resending ACK with TTL {parsed.ipTTL}"
await injectFd.injectTcpPacket(parsed) await injectFd.injectTcpPacket(parsed)
return return
proc initPuncher*(srcPort: Port, dstIp: IpAddress, dstPorts: seq[Port], proc initPuncher*(): TcpSyniPuncher = TcpSyniPuncher()
seqNums: seq[uint32] = @[]): TcpSyniPuncher =
let localIp = getPrimaryIPAddr(dstIp)
# TODO: do real port prediction
var predictedDstPorts = newSeq[Port](1)
let basePort = min(dstPorts[1].uint16,
uint16.high - (predictedDstPorts.len - 1).uint16)
for i in 0 .. predictedDstPorts.len - 1:
predictedDstPorts[i] = Port(basePort + i.uint16)
result = TcpSyniPuncher(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
dstPorts: predictedDstPorts, seqNums: seqNums)
proc addFirewallRules(puncher: TcpSyniPuncher) {.async.} = proc findConnectAttempt(puncher: TcpSyniPuncher, srcIp: IpAddress,
for dstPort in puncher.dstPorts: srcPort: Port, dstIp: IpAddress,
let rule = makeFirewallRule(puncher.srcIp, puncher.srcPort, dstPorts: seq[Port]): int =
puncher.dstIp, dstPort) for (index, attempt) in puncher.connectAttempts.pairs():
try: if attempt.srcIp == srcIp and attempt.srcPort == srcPort and
await iptablesInsert("INPUT", rule) attempt.dstIp == dstIp and
puncher.firewallRules.add(rule) attempt.dstPorts.any(proc (p: Port): bool = p in dstPorts):
except OSError as e: return index
echo "cannot add firewall rule: ", e.msg
raise newException(PunchHoleError, e.msg) proc findAcceptAttempt(puncher: TcpSyniPuncher, srcIp: IpAddress,
srcPort: Port, dstIp: IpAddress,
dstPorts: seq[Port]): int =
for (index, attempt) in puncher.acceptAttempts.pairs():
if attempt.srcIp == srcIp and attempt.srcPort == srcPort and
attempt.dstIp == dstIp and
attempt.dstPorts.any(proc (p: Port): bool = p in dstPorts):
return index
proc findAcceptAttemptsByLocalAddr(puncher: TcpSyniPuncher, address: IpAddress,
port: Port): seq[AcceptAttempt] =
for attempt in puncher.acceptAttempts:
if attempt.srcIp == address and attempt.srcPort == port:
result.add(attempt)
proc predictPortRange(dstPorts: seq[Port]): seq[Port] =
# TODO: do real port prediction
result = newSeq[Port](1)
let basePort = min(dstPorts[1].uint16,
uint16.high - (result.len - 1).uint16)
for i in 0 .. result.len - 1:
result[i] = Port(basePort + i.uint16)
proc cleanup*(puncher: TcpSyniPuncher) {.async.} = proc cleanup*(puncher: TcpSyniPuncher) {.async.} =
for rule in puncher.firewallRules: for attempt in puncher.connectAttempts:
try: await attempt.deleteFirewallRules()
await iptablesDelete("INPUT", rule) for attempt in puncher.acceptAttempts:
except OSError: await attempt.deleteFirewallRules()
# At least we tried
discard
proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port, proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port,
future: Future[AsyncSocket]) {.async.} = future: Future[AsyncSocket]) {.async.} =
@ -144,39 +186,45 @@ proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port,
echo &"connection {srcIP}:{srcPort.int} -> {dstIp}:{dstPort.int} failed: ", e.msg echo &"connection {srcIP}:{srcPort.int} -> {dstIp}:{dstPort.int} failed: ", e.msg
discard discard
proc connectParallel(puncher: TcpSyniPuncher): Future[AsyncSocket] = proc connect*(puncher: TcpSyniPuncher, srcPort: Port, dstIp: IpAddress,
result = newFuture[AsyncSocket]("doConnect") dstPorts: seq[Port],
for dstPort in puncher.dstPorts:
asyncCheck doConnect(puncher.srcIp, puncher.srcPort, puncher.dstIp, dstPort, result)
proc connect*(puncher: TcpSyniPuncher,
progressCb: PunchProgressCb): Future[AsyncSocket] {.async.} = progressCb: PunchProgressCb): Future[AsyncSocket] {.async.} =
let iface = fromIpAddress(puncher.srcIp) let localIp = getPrimaryIPAddr(dstIp)
if puncher.findConnectAttempt(localIp, srcPort, dstIp, dstPorts) != -1:
raise newException(PunchHoleError, "hole punching for given parameters already active")
let attempt = ConnectAttempt(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
dstPorts: predictPortRange(dstPorts))
puncher.connectAttempts.add(attempt)
await attempt.addFirewallRules()
let iface = fromIpAddress(attempt.srcIp)
let captureSeqFd = setupEthernetCapturingSocket(iface) let captureSeqFd = setupEthernetCapturingSocket(iface)
let captureAckFd = setupEthernetCapturingSocket(iface) let captureAckFd = setupEthernetCapturingSocket(iface)
let injectAckFd = setupTcpInjectingSocket() let injectAckFd = setupTcpInjectingSocket()
asyncCheck puncher.captureSeqNumbers(captureSeqFd, progressCb) asyncCheck attempt.captureSeqNumbers(captureSeqFd, progressCb)
asyncCheck puncher.captureAndResendAck(captureAckFd, injectAckFd) asyncCheck attempt.captureAndResendAck(captureAckFd, injectAckFd)
await puncher.addFirewallRules()
try: try:
let connectParallelFuture = puncher.connectParallel() let connectFuture = newFuture[AsyncSocket]("connect")
await connectParallelFuture or sleepAsync(Timeout) for dstPort in attempt.dstPorts:
await puncher.cleanup() asyncCheck doConnect(attempt.srcIp, attempt.srcPort, attempt.dstIp,
if connectParallelFuture.finished(): dstPort, connectfuture)
result = connectParallelFuture.read() await connectFuture or sleepAsync(Timeout)
await attempt.deleteFirewallRules()
puncher.connectAttempts.del(puncher.connectAttempts.find(attempt))
if connectFuture.finished():
result = connectFuture.read()
else: else:
raise newException(PunchHoleError, "timeout") raise newException(PunchHoleError, "timeout")
except OSError as e: except OSError as e:
raise newException(PunchHoleError, e.msg) raise newException(PunchHoleError, e.msg)
proc prepareAccept(puncher: TcpSyniPuncher) {.async.} = proc prepareAccept(attempt: AcceptAttempt) {.async.} =
for dstPort in puncher.dstPorts: for dstPort in attempt.dstPorts:
try: try:
let sock = newAsyncSocket() let sock = newAsyncSocket()
sock.setSockOpt(OptReuseAddr, true) sock.setSockOpt(OptReuseAddr, true)
sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2) sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2)
sock.bindAddr(puncher.srcPort, $(puncher.srcIp)) sock.bindAddr(attempt.srcPort, $(attempt.srcIp))
let connectFuture = sock.connect($(puncher.dstIp), dstPort) let connectFuture = sock.connect($(attempt.dstIp), dstPort)
await connectFuture or sleepAsync(Timeout) await connectFuture or sleepAsync(Timeout)
if connectFuture.finished(): if connectFuture.finished():
echo "connected during accept phase" echo "connected during accept phase"
@ -184,38 +232,68 @@ proc prepareAccept(puncher: TcpSyniPuncher) {.async.} =
except OSError: except OSError:
discard discard
proc accept*(puncher: TcpSyniPuncher): Future[AsyncSocket] {.async.} = proc doAccept(puncher: TcpSyniPuncher, srcIp: IpAddress,
await puncher.prepareAccept() srcPort: Port) {.async.} =
await puncher.addFirewallRules()
try:
let rawFd = setupTcpInjectingSocket()
for dstPort in puncher.dstPorts:
for seqNum in puncher.seqNums:
let ipPacket = IpPacket(protocol: tcp,
ipAddrSrc: puncher.dstIp,
ipAddrDst: puncher.srcIp,
ipTTL: 64,
tcpPortSrc: dstPort,
tcpPortDst: puncher.srcPort,
tcpSeqNumber: seqNum,
tcpAckNumber: 0,
tcpFlags: {SYN},
tcpWindowSize: 1452 * 10)
echo &"[{ipPacket.ipAddrSrc}:{ipPacket.tcpPortSrc} -> {ipPacket.ipAddrDst}:{ipPacket.tcpPortDst}, SEQ {ipPacket.tcpSeqNumber}] injecting SYN"
asyncCheck rawFd.injectTcpPacket(ipPacket)
let sock = newAsyncSocket() let sock = newAsyncSocket()
sock.setSockOpt(OptReuseAddr, true) sock.setSockOpt(OptReuseAddr, true)
sock.bindAddr(puncher.srcPort, $(puncher.srcIp)) sock.bindAddr(srcPort, $(srcIp))
sock.listen() sock.listen()
echo &"accepting connections from {puncher.dstIp}:{puncher.dstPorts[0].int}" while true:
let acceptFuture = sock.accept() let acceptFuture = sock.accept()
await acceptFuture or sleepAsync(Timeout) await acceptFuture or sleepAsync(Timeout)
await puncher.cleanup()
if acceptFuture.finished(): if acceptFuture.finished():
result = acceptFuture.read() let peer = acceptFuture.read()
let (peerAddr, peerPort) = peer.getPeerAddr()
let peerIp = parseIpAddress(peerAddr)
let i = puncher.findAcceptAttempt(srcIp, srcPort, peerIp, @[peerPort])
if i == -1:
echo "Accepted connection, but no attempt found. Discarding."
else:
let attempt = puncher.acceptAttempts[i]
attempt.future.complete(peer)
else:
let attempts = puncher.findAcceptAttemptsByLocalAddr(srcIp, srcPort)
# FIXME: should attempts have timestamps, so we can decide here which ones to delete?
if attempts.len() <= 1:
break
proc accept*(puncher: TcpSyniPuncher, srcPort: Port, dstIp: IpAddress,
dstPorts: seq[Port],
seqNums: seq[uint32]): Future[AsyncSocket] {.async.} =
let localIp = getPrimaryIPAddr(dstIp)
let existingAttempts = puncher.findAcceptAttemptsByLocalAddr(localIp, srcPort)
if existingAttempts.len() == 0:
echo &"accepting connections from {dstIp}:{dstPorts[0].int}"
asyncCheck puncher.doAccept(localIp, srcPort)
else:
for a in existingAttempts:
if a.dstIp == dstIp and
a.dstPorts.any(proc (p: Port): bool = p in dstPorts):
raise newException(PunchHoleError, "hole punching for given parameters already active")
let attempt = AcceptAttempt(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
dstPorts: dstPorts, seqNums: seqNums,
future: newFuture[AsyncSocket]("accept"))
puncher.acceptAttempts.add(attempt)
await attempt.addFirewallRules()
await attempt.prepareAccept()
try:
let rawFd = setupTcpInjectingSocket()
for dstPort in attempt.dstPorts:
for seqNum in attempt.seqNums:
let ipPacket = IpPacket(protocol: tcp, ipAddrSrc: attempt.dstIp,
ipAddrDst: attempt.srcIp, ipTTL: 64,
tcpPortSrc: dstPort,
tcpPortDst: attempt.srcPort,
tcpSeqNumber: seqNum, tcpAckNumber: 0,
tcpFlags: {SYN}, tcpWindowSize: 1452 * 10)
echo &"[{ipPacket.ipAddrSrc}:{ipPacket.tcpPortSrc} -> {ipPacket.ipAddrDst}:{ipPacket.tcpPortDst}, SEQ {ipPacket.tcpSeqNumber}] injecting SYN"
asyncCheck rawFd.injectTcpPacket(ipPacket)
await attempt.future or sleepAsync(Timeout)
await attempt.deleteFirewallRules()
puncher.acceptAttempts.del(puncher.acceptAttempts.find(attempt))
if attempt.future.finished():
result = attempt.future.read()
else: else:
raise newException(PunchHoleError, "timeout") raise newException(PunchHoleError, "timeout")
except OSError as e: except OSError as e:
echo &"accepting connections from {puncher.dstIP}:{puncher.dstPorts[0].int} failed: ", e.msg echo &"accepting connections from {dstIP}:{dstPorts[0].int} failed: ", e.msg
await puncher.cleanup()
raise newException(PunchHoleError, e.msg)