# punchd/tcp_syni.nim
#
# 307 lines
# 12 KiB
# Nim

import asyncfutures, asyncdispatch, asyncnet, strformat
from net import IpAddress, Port, `$`, `==`, toSockAddr, parseIpAddress
from nativesockets import SockAddr, Sockaddr_storage, SockLen, setSockOptInt
from random import randomize, rand
from sequtils import any
import asyncutils
import ip_packet
import network_interface
import raw_socket
import utils
# Socket-option identifiers taken from the C header <netinet/in.h>. They are
# passed to setSockOptInt in doConnect to change the TTL of outgoing packets.
var IPPROTO_IP {.importc: "IPPROTO_IP", header: "<netinet/in.h>".}: cint
var IP_TTL {.importc: "IP_TTL", header: "<netinet/in.h>".}: cint
# How long connect/accept/doAccept wait before giving up on a pending future.
# The value is handed to sleepAsync, i.e. it is in milliseconds.
const Timeout = 3000
type
  # State of one outgoing hole-punching attempt created by `connect`.
  ConnectAttempt* = ref object
    srcIp: IpAddress           # local address, from getPrimaryIPAddr(dstIp)
    srcPort: Port              # local port the connecting socket is bound to
    dstIp: IpAddress           # address of the peer
    dstPorts: seq[Port]        # candidate peer ports (from predictPortRange)
    firewallRules: seq[string] # iptables rules added by addFirewallRules

  # State of one incoming hole-punching attempt created by `accept`.
  AcceptAttempt* = ref object
    srcIp: IpAddress           # local address, from getPrimaryIPAddr(dstIp)
    srcPort: Port              # local port the listening socket is bound to
    dstIp: IpAddress           # address of the peer
    dstPorts: seq[Port]        # candidate peer ports (from predictPortRange)
    seqNums: seq[uint32]       # sequence numbers used for the injected
                               # "incoming" SYN packets
    firewallRules: seq[string] # iptables rules added by addFirewallRules
    future: Future[AsyncSocket] # completed by doAccept with the peer socket

  # Holds every attempt that is currently in progress.
  TcpSyniPuncher* = ref object
    connectAttempts: seq[ConnectAttempt]
    acceptAttempts: seq[AcceptAttempt]

  # Callback invoked by captureSeqNumbers with the sequence numbers of the
  # captured outgoing SYN packets.
  PunchProgressCb* = proc (seqNums: seq[uint32]) {.async.}

  # Error raised by this module when hole punching fails.
  PunchHoleError* = object of ValueError
proc makeFirewallRule(srcIp: IpAddress, srcPort: Port,
                      dstIp: IpAddress, dstPort: Port): string =
  ## Returns the iptables rule specification (everything that follows the
  ## chain name) matching ICMP time-exceeded messages addressed to `srcIp`
  ## that conntrack relates to the TCP flow
  ## `srcIp:srcPort -> dstIp:dstPort`, with DROP as the target.
  ##
  ## `fmt` literals are raw, so every backslash/newline pair below is part of
  ## the returned string. The continuation lines must stay unindented to keep
  ## the string unchanged.
  # NOTE(review): presumably the shell spawned by asyncExecCmd joins the
  # backslash-continued lines -- confirm.
  fmt"""-w \
-d {srcIp} \
-p icmp \
--icmp-type time-exceeded \
-m conntrack \
--ctstate RELATED \
--ctproto tcp \
--ctorigsrc {srcIp} \
--ctorigsrcport {srcPort.int} \
--ctorigdst {dstIp} \
--ctorigdstport {dstPort.int} \
-j DROP"""
proc iptablesInsert(chain: string, rule: string) {.async.} =
  ## Runs `iptables -I` to insert `rule` into `chain`.
  # NOTE(review): the value returned by asyncExecCmd is discarded; if it is
  # the exit status, a failing iptables call goes unnoticed -- confirm.
  discard await asyncExecCmd(fmt"iptables -I {chain} {rule}")
proc iptablesDelete(chain: string, rule: string) {.async.} =
  ## Runs `iptables -D` to delete `rule` from `chain`.
  # NOTE(review): the value returned by asyncExecCmd is discarded; if it is
  # the exit status, a failing iptables call goes unnoticed -- confirm.
  discard await asyncExecCmd(fmt"iptables -D {chain} {rule}")
proc addFirewallRules[T](attempt: T) {.async.} =
  ## Inserts one INPUT rule per entry of `attempt.dstPorts` and records every
  ## rule that was inserted in `attempt.firewallRules`.
  ##
  ## Raises `PunchHoleError` as soon as one insertion fails with `OSError`;
  ## rules inserted before that stay recorded on the attempt.
  for port in attempt.dstPorts:
    let spec = makeFirewallRule(attempt.srcIp, attempt.srcPort,
                                attempt.dstIp, port)
    try:
      await iptablesInsert("INPUT", spec)
      # Only remember rules that were really inserted.
      attempt.firewallRules.add(spec)
    except OSError as err:
      echo "cannot add firewall rule: ", err.msg
      raise newException(PunchHoleError, err.msg)
proc deleteFirewallRules[T](attempt: T) {.async.} =
  ## Best-effort removal of every rule recorded in `attempt.firewallRules`.
  ## An `OSError` from a single deletion is ignored so the remaining rules
  ## are still processed. The recorded list itself is left untouched.
  for spec in attempt.firewallRules:
    # FIXME: close sock?
    try:
      await iptablesDelete("INPUT", spec)
    except OSError:
      # At least we tried
      discard
proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} =
  ## Serializes `ipPacket` and sends it through `rawFd` to the packet's own
  ## destination address and port. Only TCP packets are accepted (asserted).
  ##
  ## Raises `PunchHoleError` when an `OSError` occurs.
  assert(ipPacket.protocol == tcp)
  try:
    let wire = serialize(ipPacket)
    var
      target: Sockaddr_storage
      targetLen: SockLen
    toSockAddr(ipPacket.ipAddrDst, ipPacket.tcpPortDst, target, targetLen)
    await rawFd.sendTo(wire.cstring, wire.len,
                       cast[ptr SockAddr](addr target), targetLen)
  except OSError as err:
    raise newException(PunchHoleError, err.msg)
proc captureSeqNumbers(attempt: ConnectAttempt, cb: PunchProgressCb) {.async.} =
  ## Captures the outgoing SYN packets of `attempt` on the interface that
  ## owns `attempt.srcIp` and passes their sequence numbers to `cb`.
  ##
  ## Capturing stops once as many sequence numbers as there are destination
  ## ports were collected, or when `recv` yields an empty string. The capture
  ## socket is closed on every exit path, including when receiving or parsing
  ## raises.
  # FIXME: timeout?
  let iface = getNetworkInterface(attempt.srcIp)
  let captureFd = setupEthernetCapturingSocket(iface)
  var seqNums = newSeq[uint32]()
  try:
    while seqNums.len < attempt.dstPorts.len:
      let packet = await captureFd.recv(4000)
      if packet == "":
        break
      let parsed = parseEthernetPacket(packet)
      # Only pure SYN packets of this attempt's flow are of interest.
      if parsed.protocol == tcp and
          parsed.ipAddrSrc == attempt.srcIp and
          parsed.tcpPortSrc.int == attempt.srcPort.int and
          parsed.ipAddrDst == attempt.dstIp and
          parsed.tcpFlags == {SYN}:
        for port in attempt.dstPorts:
          if parsed.tcpPortDst.int == port.int:
            seqNums.add(parsed.tcpSeqNumber)
            break
  finally:
    # Previously skipped when an exception escaped the loop (fd leak).
    closeSocket(captureFd)
  await cb(seqNums)
proc captureAndResendAck(attempt: ConnectAttempt) {.async.} =
  ## Waits for the first outgoing pure-ACK packet of `attempt`, rewrites its
  ## TTL to 64 and injects that copy through a raw socket.
  ##
  ## Returns after the first re-sent ACK or when `recv` yields an empty
  ## string. Both sockets are closed on every exit path, including when
  ## `injectTcpPacket` raises `PunchHoleError`.
  let iface = getNetworkInterface(attempt.srcIp)
  let captureFd = setupEthernetCapturingSocket(iface)
  let injectFd = setupTcpInjectingSocket()
  try:
    block loops:
      while true:
        let packet = await captureFd.recv(4000)
        if packet == "":
          break
        var parsed = parseEthernetPacket(packet)
        if parsed.protocol == tcp and
            parsed.ipAddrSrc == attempt.srcIp and
            parsed.tcpPortSrc.int == attempt.srcPort.int and
            parsed.ipAddrDst == attempt.dstIp and
            parsed.tcpFlags == {ACK}:
          for port in attempt.dstPorts:
            if parsed.tcpPortDst.int == port.int:
              parsed.ipTTL = 64
              echo &"[{parsed.ipAddrSrc}:{parsed.tcpPortSrc.int} -> {parsed.ipAddrDst}:{parsed.tcpPortDst}, SEQ {parsed.tcpSeqNumber}] resending ACK with TTL {parsed.ipTTL}"
              await injectFd.injectTcpPacket(parsed)
              # One re-sent ACK is enough; leave both loops.
              break loops
  finally:
    # Previously skipped when an exception escaped the loop (fd leak).
    closeSocket(captureFd)
    closeSocket(injectFd)
proc initPuncher*(): TcpSyniPuncher =
  ## Seeds the global random number generator and returns a puncher without
  ## any registered attempts.
  randomize()
  result = TcpSyniPuncher()
proc findConnectAttempt(puncher: TcpSyniPuncher, srcIp: IpAddress,
                        srcPort: Port, dstIp: IpAddress,
                        dstPorts: seq[Port]): int =
  ## Returns the index of the first connect attempt whose local address,
  ## local port and peer address equal the arguments and that shares at least
  ## one port with `dstPorts`; -1 when there is none.
  result = -1
  for idx, candidate in puncher.connectAttempts:
    let sameEndpoints = candidate.srcIp == srcIp and
                        candidate.srcPort == srcPort and
                        candidate.dstIp == dstIp
    if sameEndpoints and
        candidate.dstPorts.any(proc (p: Port): bool = p in dstPorts):
      return idx
proc findAcceptAttempt(puncher: TcpSyniPuncher, srcIp: IpAddress,
                       srcPort: Port, dstIp: IpAddress,
                       dstPorts: seq[Port]): int =
  ## Returns the index of the first accept attempt whose local address,
  ## local port and peer address equal the arguments and that shares at least
  ## one port with `dstPorts`; -1 when there is none.
  result = -1
  for idx, candidate in puncher.acceptAttempts:
    let sameEndpoints = candidate.srcIp == srcIp and
                        candidate.srcPort == srcPort and
                        candidate.dstIp == dstIp
    if sameEndpoints and
        candidate.dstPorts.any(proc (p: Port): bool = p in dstPorts):
      return idx
proc findAcceptAttemptsByLocalAddr(puncher: TcpSyniPuncher, address: IpAddress,
                                   port: Port): seq[AcceptAttempt] =
  ## Collects, in registration order, every accept attempt that is bound to
  ## the local endpoint `address:port`.
  result = @[]
  for candidate in puncher.acceptAttempts:
    let sameLocalEndpoint = candidate.srcIp == address and
                            candidate.srcPort == port
    if sameLocalEndpoint:
      result.add(candidate)
proc predictPortRange(dstPorts: seq[Port]): seq[Port] =
  ## Returns the list of peer ports to try. For now this is a single port:
  ## the second entry of `dstPorts`, or the only entry when just one port was
  ## supplied (previously a one-element list raised IndexDefect).
  ##
  ## The base port is clamped so that `basePort + result.len - 1` cannot
  ## exceed `uint16.high`. An empty `dstPorts` still raises IndexDefect.
  # TODO: do real port prediction
  # NOTE(review): accept() logs dstPorts[0] as the peer port while this proc
  # seeds from index 1 -- confirm which entry is meant to be the base.
  result = newSeq[Port](1)
  let seed = dstPorts[min(1, dstPorts.high)]
  let basePort = min(seed.uint16,
                     uint16.high - (result.len - 1).uint16)
  for i in 0 .. result.len - 1:
    result[i] = Port(basePort + i.uint16)
proc cleanup*(puncher: TcpSyniPuncher) {.async.} =
  ## Unregisters every pending connect and accept attempt and removes the
  ## firewall rules recorded on each of them.
  while puncher.connectAttempts.len() != 0:
    await puncher.connectAttempts.pop().deleteFirewallRules()
  # The second loop used to pop from connectAttempts, which is already empty
  # here: accept attempts were never cleaned up and the pop itself failed.
  while puncher.acceptAttempts.len() != 0:
    await puncher.acceptAttempts.pop().deleteFirewallRules()
proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port,
               future: Future[AsyncSocket]) {.async.} =
  ## Tries to connect from `srcIp:srcPort` to `dstIp:dstPort` and completes
  ## `future` with the connected socket on success.
  ##
  ## The socket starts with an IP TTL of 2 and is switched to a TTL of 64
  ## once the connection is established. Any `OSError` (socket options, bind
  ## or connect) is logged and the socket is closed; `future` is then left
  ## untouched. If `future` was already completed by another call, the
  ## redundant socket is closed instead of completing the future twice.
  let sock = newAsyncSocket()
  try:
    sock.setSockOpt(OptReuseAddr, true)
    sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2)
    echo &"doConnect {srcIp}:{srcPort} -> {dstIp}:{dstPort}"
    # Binding used to happen outside the try block: a failure leaked the
    # socket and surfaced as an unhandled error through asyncCheck.
    sock.bindAddr(srcPort, $srcIp)
    await sock.connect($dstIp, dstPort)
    sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 64)
    if future.finished():
      # Another attempt won the race; completing again would raise.
      sock.close()
    else:
      future.complete(sock)
  except OSError as e:
    echo &"connection {srcIP}:{srcPort.int} -> {dstIp}:{dstPort.int} failed: ", e.msg
    sock.close()
proc connect*(puncher: TcpSyniPuncher, srcPort: Port, dstIp: IpAddress,
              dstPorts: seq[Port],
              progressCb: PunchProgressCb): Future[AsyncSocket] {.async.} =
  ## Punches a hole towards `dstIp` from local port `srcPort` and returns the
  ## connected socket.
  ##
  ## `progressCb` receives the sequence numbers of the captured outgoing SYN
  ## packets. The attempt is always unregistered and its firewall rules are
  ## always removed before this proc returns or raises.
  ##
  ## Raises `PunchHoleError` when `dstPorts` is empty, when an attempt with
  ## overlapping parameters is already active, when the firewall rules cannot
  ## be installed, or when no connection is established within `Timeout`.
  if dstPorts.len == 0:
    # Previously an IndexDefect inside predictPortRange.
    raise newException(PunchHoleError, "no destination ports given")
  let localIp = getPrimaryIPAddr(dstIp)
  if puncher.findConnectAttempt(localIp, srcPort, dstIp, dstPorts) != -1:
    raise newException(PunchHoleError, "hole punching for given parameters already active")
  let attempt = ConnectAttempt(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
                               dstPorts: predictPortRange(dstPorts))
  puncher.connectAttempts.add(attempt)
  let connectFuture = newFuture[AsyncSocket]("connect")
  # Errors are remembered instead of propagated right away so the cleanup
  # below also runs when installing the firewall rules fails.
  var failure: ref PunchHoleError = nil
  try:
    await attempt.addFirewallRules()
    asyncCheck attempt.captureSeqNumbers(progressCb)
    asyncCheck attempt.captureAndResendAck()
    for dstPort in attempt.dstPorts:
      asyncCheck doConnect(attempt.srcIp, attempt.srcPort, attempt.dstIp,
                           dstPort, connectFuture)
    await connectFuture or sleepAsync(Timeout)
  except OSError as e:
    failure = newException(PunchHoleError, e.msg)
  except PunchHoleError as e:
    failure = e
  await attempt.deleteFirewallRules()
  # The attempt may already be gone if cleanup() ran in the meantime.
  let index = puncher.connectAttempts.find(attempt)
  if index != -1:
    puncher.connectAttempts.del(index)
  if failure != nil:
    raise failure
  if connectFuture.finished():
    result = connectFuture.read()
  else:
    raise newException(PunchHoleError, "timeout")
proc doAccept(puncher: TcpSyniPuncher, srcIp: IpAddress,
              srcPort: Port) {.async.} =
  ## Listens on `srcIp:srcPort` and hands every accepted connection to the
  ## accept attempt that matches the peer's address and port by completing
  ## that attempt's future. Connections without a matching attempt are
  ## closed.
  ##
  ## After a matched connection, or after `Timeout` passed without one, the
  ## loop ends as soon as at most one accept attempt is registered for this
  ## local endpoint; the listening socket is then closed.
  # NOTE(review): indentation was lost in this copy; the nesting below is
  # reconstructed -- confirm against the original file.
  let sock = newAsyncSocket()
  sock.setSockOpt(OptReuseAddr, true)
  sock.bindAddr(srcPort, $(srcIp))
  sock.listen()
  while true:
    # NOTE(review): a fresh accept() is issued on every iteration, also when
    # the previous one is still pending after a timeout -- confirm that
    # connections delivered to an abandoned future cannot get lost.
    let acceptFuture = sock.accept()
    await acceptFuture or sleepAsync(Timeout)
    if acceptFuture.finished():
      let peer = acceptFuture.read()
      let (peerAddr, peerPort) = peer.getPeerAddr()
      let peerIp = parseIpAddress(peerAddr)
      let i = puncher.findAcceptAttempt(srcIp, srcPort, peerIp, @[peerPort])
      if i == -1:
        echo "Accepted connection, but no attempt found. Discarding."
        peer.close()
        continue
      else:
        let attempt = puncher.acceptAttempts[i]
        attempt.future.complete(peer)
    let attempts = puncher.findAcceptAttemptsByLocalAddr(srcIp, srcPort)
    # FIXME: should attempts have timestamps, so we can decide here which ones to delete?
    if attempts.len() <= 1:
      break
  sock.close()
proc injectSynPackets(attempt: AcceptAttempt) {.async.} =
  ## Injects the SYN packets of an accept attempt through a raw socket.
  ##
  ## For every destination port, one outgoing SYN (TTL 2, random sequence
  ## number) is sent from the local endpoint to the peer. It is followed by
  ## one SYN per entry of `attempt.seqNums`, each carrying the peer's
  ## endpoint as source, the local endpoint as destination and a TTL of 64.
  ##
  ## The raw socket is closed on every exit path, including when
  ## `injectTcpPacket` raises `PunchHoleError`.
  let rawFd = setupTcpInjectingSocket()
  try:
    for dstPort in attempt.dstPorts:
      let synOut = IpPacket(protocol: tcp, ipAddrSrc: attempt.srcIp,
                            ipAddrDst: attempt.dstIp, ipTTL: 2,
                            tcpPortSrc: attempt.srcPort, tcpPortDst: dstPort,
                            tcpSeqNumber: rand(uint32), tcpAckNumber: 0,
                            tcpFlags: {SYN}, tcpWindowSize: 1452 * 10)
      echo &"[{synOut.ipAddrSrc}:{synOut.tcpPortSrc} -> {synOut.ipAddrDst}:{synOut.tcpPortDst}, SEQ {synOut.tcpSeqNumber}] injecting outgoing SYN"
      await rawFd.injectTcpPacket(synOut)
      for seqNum in attempt.seqNums:
        let synIn = IpPacket(protocol: tcp, ipAddrSrc: attempt.dstIp,
                             ipAddrDst: attempt.srcIp, ipTTL: 64,
                             tcpPortSrc: dstPort,
                             tcpPortDst: attempt.srcPort,
                             tcpSeqNumber: seqNum, tcpAckNumber: 0,
                             tcpFlags: {SYN}, tcpWindowSize: 1452 * 10)
        echo &"[{synIn.ipAddrSrc}:{synIn.tcpPortSrc} -> {synIn.ipAddrDst}:{synIn.tcpPortDst}, SEQ {synIn.tcpSeqNumber}] injecting incoming SYN"
        await rawFd.injectTcpPacket(synIn)
  finally:
    # Previously skipped when an injection failed (fd leak).
    closeSocket(rawFd)
proc accept*(puncher: TcpSyniPuncher, srcPort: Port, dstIp: IpAddress,
             dstPorts: seq[Port],
             seqNums: seq[uint32]): Future[AsyncSocket] {.async.} =
  ## Waits for the peer `dstIp` to connect to local port `srcPort` and
  ## returns the accepted socket.
  ##
  ## A listener (`doAccept`) is started when no accept attempt exists yet for
  ## the local endpoint. `seqNums` are the sequence numbers used for the
  ## injected "incoming" SYN packets. The attempt is always unregistered and
  ## its firewall rules are always removed before this proc returns or
  ## raises.
  ##
  ## Raises `PunchHoleError` when `dstPorts` is empty, when an attempt with
  ## overlapping parameters is already active, when installing the firewall
  ## rules or injecting packets fails, or when no matching connection arrives
  ## within `Timeout`.
  if dstPorts.len == 0:
    # Previously an IndexDefect in the log statement below.
    raise newException(PunchHoleError, "no destination ports given")
  let localIp = getPrimaryIPAddr(dstIp)
  let existingAttempts = puncher.findAcceptAttemptsByLocalAddr(localIp, srcPort)
  if existingAttempts.len() == 0:
    echo &"accepting connections from {dstIp}:{dstPorts[0].int}"
    asyncCheck puncher.doAccept(localIp, srcPort)
  else:
    for a in existingAttempts:
      if a.dstIp == dstIp and
          a.dstPorts.any(proc (p: Port): bool = p in dstPorts):
        raise newException(PunchHoleError, "hole punching for given parameters already active")
  let attempt = AcceptAttempt(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
                              dstPorts: predictPortRange(dstPorts),
                              seqNums: seqNums,
                              future: newFuture[AsyncSocket]("accept"))
  puncher.acceptAttempts.add(attempt)
  # Errors are remembered instead of propagated right away: a PunchHoleError
  # from the setup steps used to bypass `except OSError`, leaving the attempt
  # registered and its firewall rules installed.
  var failure: ref PunchHoleError = nil
  try:
    await attempt.addFirewallRules()
    await attempt.injectSynPackets()
    await attempt.future or sleepAsync(Timeout)
  except OSError as e:
    failure = newException(PunchHoleError, e.msg)
  except PunchHoleError as e:
    failure = e
  await attempt.deleteFirewallRules()
  # The attempt may already be gone if cleanup() ran in the meantime.
  let index = puncher.acceptAttempts.find(attempt)
  if index != -1:
    puncher.acceptAttempts.del(index)
  if failure != nil:
    raise failure
  if attempt.future.finished():
    result = attempt.future.read()
  else:
    raise newException(PunchHoleError, "timeout")