-- this code is based on src/tftp.lua from https://github.com/ldrumm/tufty
-- which is distributed under the MIT License
-- Starting upstream revision 3cb95c869e2fe74cc61ca303d88af6c5daad6734
--
-- Changes made since then are mostly to make it work better with
-- luasocket

--[[
This package provides interfaces for handling TFTP requests as specified by
rfc1350, rfc1782, rfc2347, rfc2348 and (partial support for rfc2349).
It should work on Standard PC servers as well as small home routers running
OpenWRT or similar.
]]

local luasocket = require("socket")

local tftp = {}

local TIMEOUT = 5
local ACK_RETRIES = 10
local TFTP_PORT = 69
local BLKSIZE = 512
-- rfc2348: valid blksize values range between 8 and 65464 octets.
local MIN_BLKSIZE = 8
local MAX_BLKSIZE = 65464
-- rfc2349: valid timeout values range between 1 and 255 seconds.
local MIN_TIMEOUT = 1
local MAX_TIMEOUT = 255

local OP_RRQ = 1
local OP_WRQ = 2
local OP_DATA = 3
local OP_ACK = 4
local OP_ERR = 5
local OP_OACK = 6

local ACKSIZE = 4

local ERR_NOTFOUND = 1
local ERR_ACCESS = 2
local ERR_ALLOC = 3
local ERR_ILLEGAL_OP = 4
local ERR_UNKNOWN_ID = 5
local ERR_EXISTS = 6
local ERR_WHO = 7

--rfc2347 specifies the options extension.
--rfc2348 specifies the blksize extension
--rfc2349 specifies the timeout and tsize extensions
local SUPPORTED_OPTIONS = {blksize=true, timeout=true, tsize=true}

-- Opcode names indexed by their numeric wire value (OP_RRQ .. OP_OACK).
local OPCODE_NAMES = {"RRQ", "WRQ", "DATA", "ACK", "ERROR", "OACK"}

-- Control characters permitted in netascii besides the printable range.
-- see https://tools.ietf.org/html/rfc764#page-11
local NETASCII_CONTROLS = {
    [0]=true, [7]=true, [8]=true, [9]=true,
    [10]=true, [11]=true, [12]=true, [13]=true,
}

--Use penlight's prettyprinter if available
pcall(require, 'pl')
local log = pretty and pretty.dump or print

local time = luasocket.gettime

--[[
``poll`` waits for socket readiness using ``luasocket.select``.

It accepts a table of records keyed by luasocket socket object (called
``fd`` throughout this file), formatted as follows:
    {
        [socket] = {
            fd = socket,
            wantread = (bool),
            wantwrite = (bool),
            ...arbitrary ignored extra fields
        },
        ...
    }
The ``readable``/``writeable`` fields of every input record are reset to
false.

It returns a table keyed by socket object, formatted as follows:
    {
        [socket] = {fd = socket, readable = (bool), writeable = (bool)},
        ...
    }
or nil when ``luasocket.select`` reports an error (including "timeout" when
an optional ``timeout`` in seconds was given).
]]
local function poll(fds, timeout)
    local wantread = {}
    local wantwrite = {}
    for _, record in pairs(fds) do
        record.readable = false
        record.writeable = false
        if record.wantwrite then
            wantwrite[#wantwrite + 1] = record.fd
        end
        if record.wantread then
            wantread[#wantread + 1] = record.fd
        end
    end

    local readable, writeable, select_err = luasocket.select(wantread, wantwrite, timeout)
    if select_err then
        return nil
    end

    local ready = {}
    for _, sock in ipairs(readable) do
        ready[sock] = ready[sock] or {fd = sock}
        ready[sock].readable = true
    end
    for _, sock in ipairs(writeable) do
        ready[sock] = ready[sock] or {fd = sock}
        ready[sock].writeable = true
    end
    return ready
end

local function UDPSocket()
    --[[ Wrap a fresh luasocket UDP object with only the methods TFTP needs.
    This in no way represents a complete UDP socket implementation.
    see http://w3.impa.br/~diego/software/luasocket/udp.html for the
    luasocket UDP API.
    ``recvfrom(length)``: when ``length`` is nil, luasocket's default
    (maximum) datagram size is used.
    ]]
    log("using luasocket")
    return {
        fd = luasocket.udp(),
        bind = function(self, address, port)
            return self.fd:setsockname(address, port)
        end,
        sendto = function(self, data, address, port)
            return self.fd:sendto(data, address, port)
        end,
        recvfrom = function(self, length)
            return self.fd:receivefrom(length)
        end,
        close = function(self)
            return self.fd:close()
        end,
    }
end

local function is_netascii(s)
    --[[Check whether a string contains only characters from the RFC764
    ascii subset: codes 32-127 plus NUL, BEL, BS, HT, LF, VT, FF and CR.
    see https://tools.ietf.org/html/rfc764#page-11
    ]]
    for i = 1, #s do
        local byte = s:byte(i)
        if byte > 127 or (byte < 32 and not NETASCII_CONTROLS[byte]) then
            return false
        end
    end
    return true
end

local function create_opcode(val)
    --[[Encode ``val`` as a 2-byte unsigned integer. Used for opcodes, block
    numbers and error codes alike. RFC1350 doesn't mention byte order;
    assume network order (big-endian).]]
    if val < 0 or val > 2^16-1 then
        error("opcodes must fit into a 16bit integer")
    end
    local high = math.floor(val / 256)
    return string.char(high, val - (high * 256))
end

local function decode_u16(packet, offset)
    --[[Decode the big-endian 16 bit unsigned integer at byte ``offset`` of
    ``packet``, or return nil if fewer than two bytes are available.]]
    local high, low = packet:byte(offset, offset + 1)
    if low == nil then
        return nil
    end
    return high * 256 + low
end

local function parse_opcode(packet)
    --Return the opcode name ("RRQ", "ACK", ...) of ``packet``, or nil if unknown.
    local opcode = decode_u16(packet, 1)
    return opcode and OPCODE_NAMES[opcode]
end

local function negotiate_options(requested, tsize)
    --[[Filter the client's rfc2347 options down to those this server
    accepts, normalising their values for the OACK.
    - blksize: values below 8 are rejected; larger values are clamped to
      65464 (rfc2348 allows the server to answer with a smaller value).
    - timeout: accepted only within 1..255 seconds (rfc2349).
    - tsize: answered with ``tsize`` from the user callback, dropped if the
      callback did not provide one.
    Returns the accepted options table, or nil if none were accepted, in
    which case the transfer proceeds as plain rfc1350 without an OACK.]]
    local accepted = {}
    for name, value in pairs(requested) do
        if SUPPORTED_OPTIONS[name] then
            accepted[name] = value
        end
    end

    if accepted.blksize ~= nil then
        local size = tonumber(accepted.blksize)
        if size and size >= MIN_BLKSIZE then
            accepted.blksize = tostring(math.floor(math.min(size, MAX_BLKSIZE)))
        else
            accepted.blksize = nil
        end
    end

    if accepted.timeout ~= nil then
        local secs = tonumber(accepted.timeout)
        if secs and secs >= MIN_TIMEOUT and secs <= MAX_TIMEOUT then
            accepted.timeout = tostring(math.floor(secs))
        else
            accepted.timeout = nil
        end
    end

    if accepted.tsize ~= nil then
        if tsize ~= nil then
            accepted.tsize = tostring(tsize)
        else
            accepted.tsize = nil
        end
    end

    if next(accepted) == nil then
        return nil
    end
    return accepted
end

function tftp:handle_RRQ(socket, host, port, source, options)
    --[[
    Create the coroutine that serves a read request to ``host``:``port`` over
    ``socket``, the per-transfer socket created by ``listen``.

    ``source`` is the user generator described in ``tftp:listen``. It is
    called with the negotiated block size. ``options`` is the table of
    accepted rfc2347 options. If options is nil, no OACK is sent and data
    transfer starts at once.

    The coroutine yields ``wantread, wantwrite`` booleans that tell the event
    loop which socket readiness should resume it next. It returns normally
    once a short final block has been acknowledged, and raises an error on
    any failure.

    NOTE(review): ACK deadlines are only checked when the coroutine is
    resumed, i.e. when a packet arrives. A silent client therefore stalls
    its transfer indefinitely, and DATA is never retransmitted.
    ]]
    local blksize = options and tonumber(options.blksize) or BLKSIZE
    local timeout_secs = options and tonumber(options.timeout) or TIMEOUT --rfc2349 timeout option
    local err = self.ERROR
    local parse_ACK = self.parse_ACK
    local yield = coroutine.yield

    local function await_ack(block)
        --[[Wait cooperatively for the client to ACK ``block``.
        Returns true when acknowledged. Returns false and "aborted" if the
        client sent an ERROR. Returns false and "timeout" once the deadline
        passes or more than ACK_RETRIES unusable packets have been seen.]]
        local deadline = time() + timeout_secs
        local unusable = 0
        repeat
            yield(true, false) -- we need to wait until the socket is readable again
            local packet, from_host, from_port = socket:recvfrom()
            if packet and (from_host ~= host or from_port ~= port) then
                --[[https://tools.ietf.org/html/rfc1350#page-5
                "If a source TID does not match, the packet should be
                discarded as erroneously sent from somewhere else. An error
                packet should be sent to the source of the incorrect packet,
                while not disturbing the transfer."
                ]]
                socket:sendto(err(ERR_UNKNOWN_ID), from_host, from_port)
            elseif packet and parse_opcode(packet) == 'ERROR' then
                return false, "aborted"
            elseif packet then
                local valid, acked_block = pcall(parse_ACK, packet)
                if valid and acked_block == block then
                    return true
                end
                -- Anything else from the client, e.g. a duplicate ACK for the
                -- previous block, is ignored rather than answered (answering
                -- duplicates causes the "Sorcerer's Apprentice" bug).
            end
            unusable = unusable + 1
        until unusable > ACK_RETRIES or time() > deadline
        return false, "timeout"
    end

    local function next_chunk()
        --[[Pull the next chunk from ``source``, yielding while it reports
        wouldblock (`true, nil`). Returns "" once the source is finished.]]
        while true do
            local okay, more, chunk = pcall(source, blksize)
            if not okay then
                socket:sendto(err("An unknown error occurred"), host, port)
                error("generator failure: " .. tostring(more), 0)
            end
            if chunk ~= nil then
                return chunk
            end
            if not more then
                return ''
            end
            yield(false, true)
        end
    end

    return coroutine.create(function()
        if options then
            --[[The data transfer should not start until we know the client is
            ready. With rfc2347 options the client must first ACK our OACK with
            block 0. Without the options extension the client can be responded
            to immediately.]]
            assert(socket:sendto(self.OACK(options), host, port))
            log(("sent OACK to %s:%d"):format(host, port))
            local acked, reason = await_ack(0)
            if not acked then
                error(("OACK not acknowledged (%s)"):format(reason), 0)
            end
        end
        log(("coroutine started on %s:%s/"):format(host, port))

        local block = 1
        while true do
            if block >= 2^16 then
                socket:sendto(err("File too big."), host, port)
                error("File too big.", 0)
            end

            local data = next_chunk()
            socket:sendto(self.DATA(data, block), host, port)

            --[[ RFC1350 requires that for every packet sent, an ACK is
            received before the next packet can be sent. ]]
            local acked, reason = await_ack(block)
            if not acked then
                if reason == "timeout" then
                    --There doesn't seem to be a standard error for timeout.
                    socket:sendto(err("Ack timeout"), host, port)
                    error("Timeout waiting for ACK", 0)
                end
                error("Transfer aborted by client", 0)
            end

            -- A block shorter than blksize (possibly empty) ends the transfer.
            if #data < blksize then
                return
            end
            block = block + 1
            yield(true, true)
        end
    end)
end

function tftp:handle_WRQ(socket, host, port, sink)
    -- Write requests are not supported yet; ``listen`` answers them with an
    -- "Illegal TFTP operation." error packet.
    error"Not Implemented"
end

function tftp:listen(rrq_generator_callback, wrq_generator_callback, hosts, port, logfile)
    --[[--
    Listen for TFTP requests on UDP ``host``:``port`` for every address in
    ``hosts``. ``hosts`` may be a single address or a table of addresses and
    defaults to 127.0.0.1. ``port`` defaults to 69. Data is read from or
    written to user-generated source/sink functions.
    ``logfile`` is accepted for compatibility but is currently unused.

    Data is generated/received by functions returned by the user-supplied
    ``rrq_generator_callback``/``wrq_generator_callback`` factory functions.
    For every valid request packet received, the generator function returned
    by ``xrq_generator_callback`` will be called expecting data.

    ``xrq_generator_callback`` is called with the requested resource (a
    C-style string with no embedded NUL chars), the client host and port.
    For a read request it may additionally return the resource size, which
    is used to answer the rfc2349 tsize option. It should return a source or
    sink function that:

    (SOURCE) takes a single argument of the required data length (the block
    size) in bytes and returns blocks of data until complete. Every block
    must be exactly that length except the last one; a shorter (or empty)
    block ends the transfer. It must return as follows:
        `true, data` on success
        `true, nil` on wouldblock but should retry next round,
        `false` on finished

    (SINK) takes two arguments:
        ``data`` to write
        ``done`` (truthy), whether all data has been received and backends can cleanup.

    The (SOURCE) model therefore supports both blocking and non-blocking
    behaviour. If the given function blocks, however, it will block the whole
    process, as Lua is single threaded. That may or may not be acceptable
    depending on your needs.

    If the requested resource is invalid or other termination conditions are
    met, (SOURCE) and (SINK) functions should raise an error.

    @return This method never returns unless interrupted. It raises an error
    if a listening address cannot be bound.
    ]]

    local function create_handler(callbacks, request, requestsocket, host, port)
        --[[
        Given a parsed request, instantiate the generator function from the
        given callbacks. Then create a new coroutine to be called when the
        state of the handler's new socket changes to available.

        On success, returns a table of the form:
            {
                handler = coroutine to call,
                socket = new socket on a random port on which all new communication will happen,
                fd = socket.fd as above,
                host = remote host,
                port = remote port,
                request = the original parsed request, including accepted options, if any,
            }
        On error, responds to the client with an ERROR packet and returns nil.
        ]]
        local okay, generator, tsize = pcall(callbacks[request.opcode], request.filename, host, port)
        if not okay then
            log(("rejected %s request for %s: %s"):format(request.opcode, request.filename, tostring(generator)))
            requestsocket:sendto(self.ERROR(ERR_NOTFOUND), host, port)
            return nil
        end

        if request.options then
            request.options = negotiate_options(request.options, tsize)
        end

        local handlersocket = UDPSocket()
        local created, handler = pcall(self['handle_' .. request.opcode],
            self, handlersocket, host, port, generator, request.options)
        if not created then
            -- e.g. write requests, which are not implemented yet
            log(("cannot serve %s request: %s"):format(request.opcode, tostring(handler)))
            handlersocket:close()
            requestsocket:sendto(self.ERROR(ERR_ILLEGAL_OP), host, port)
            return nil
        end

        --RFC1350 requires WRQ requests to be responded to with a zero TID before
        --transfer commences, but when responding to an options request, it is
        --dropped. It must come from the new socket, which is the server's TID.
        if request.opcode == 'WRQ' and request.options == nil then
            handlersocket:sendto(self.ACK(0), host, port)
        end

        return {
            handler=handler,
            socket=handlersocket,
            fd=handlersocket.fd,
            host=host,
            port=port,
            request=request,
            wantread=false,
            wantwrite=true,
        }
    end

    local function accept(listener)
        --[[
        Read an incoming request from ``listener`` and parse it.
        Returns `host, port, request` for a valid read/write request.
        Returns nil on a receive error; malformed requests are silently
        dropped.
        ]]
        local msg, host, port = listener:recvfrom()
        if not msg then
            return nil
        end
        local okay, request = pcall(self.parse_XRQ, msg)
        if not okay then
            return nil
        end
        return host, port, request
    end

    local user_generator_callbacks = {RRQ=rrq_generator_callback, WRQ=wrq_generator_callback}
    local port = port or TFTP_PORT

    local addresses
    if type(hosts) == 'table' then
        addresses = hosts
    elseif hosts ~= nil then
        addresses = {hosts}
    else
        addresses = {'127.0.0.1'}
    end

    --One listening socket per address: a UDP socket can only be bound once.
    local handlers = {}
    for _, address in pairs(addresses) do
        local listener = UDPSocket()
        assert(listener:bind(address, port))
        handlers[listener.fd] = {fd=listener.fd, socket=listener, listener=true, wantread=true}
    end

    --[[The main event loop does three things:
        1. Accepts new connections.
        2. Handles events occurring on all sockets by dispatching to a handler coroutine.
        3. Removes finished requests from the queue and destroys the sockets.
    ]]
    while true do
        local ready_fds = poll(handlers) or {}
        for fd, status in pairs(ready_fds) do
            local entry = handlers[fd]
            if entry and entry.listener then
                if status.readable then
                    --we've got a listener and should accept a new connection
                    local host, remote_port, request = accept(entry.socket)
                    if host ~= nil then
                        log(("accepted new %s request - %s:%s/%s"):format(request.opcode, host, remote_port, request.filename))
                        local handler = create_handler(user_generator_callbacks, request, entry.socket, host, remote_port)
                        if handler then
                            handlers[handler.fd] = handler
                        end
                    end
                end
            elseif entry and entry.handler then
                --We've received an event on a socket associated with an existing handler coroutine.
                local co = entry.handler
                local wanted = (entry.wantread and status.readable) or (entry.wantwrite and status.writeable)
                if coroutine.status(co) ~= 'dead' and wanted then
                    local okay, wantread, wantwrite = coroutine.resume(co)
                    if okay then
                        entry.wantread = wantread
                        entry.wantwrite = wantwrite
                    else
                        log(("transfer %s:%s/%s failed: %s"):format(
                            entry.host, entry.port, entry.request.filename, tostring(wantread)))
                    end
                end
                -- Check the status after resuming: both a normal return and an
                -- error leave the coroutine dead.
                if coroutine.status(co) == 'dead' then
                    entry.socket:close()
                    handlers[fd] = nil
                end
            end
        end
    end
end

--[[
    RRQ/WRQ read/write request packets
    https://tools.ietf.org/html/rfc1350
             2 bytes    string   1 byte     string   1 byte
             ------------------------------------------------
            | Opcode |  Filename  |   0  |    Mode    |   0  |
             ------------------------------------------------
                        Figure 5-1: RRQ/WRQ packet
]]
function tftp.RRQ(filename)
    -- RFC1350:"The mail mode is obsolete and should not be implemented or used."
    -- We don't support netascii, which leaves 'octet' mode only
    return table.concat({create_opcode(OP_RRQ), filename, '\0', "octet", '\0'}, '')
end

local function split_cstrings(s, start)
    --[[Split ``s`` from byte ``start`` into its NUL-terminated strings. Any
    bytes after the last NUL are dropped. A plain byte scan is used because
    Lua 5.1 and LuaJIT do not handle embedded NUL bytes in patterns.]]
    local parts = {}
    local first = start
    for i = start, #s do
        if s:byte(i) == 0 then
            parts[#parts + 1] = s:sub(first, i - 1)
            first = i + 1
        end
    end
    return parts
end

function tftp.parse_XRQ(request)
    --[[Parse a RRQ/WRQ packet into
    {opcode=, filename=, mode=, options=(table or nil)}; raise on malformed input.]]
    local opcode = assert(parse_opcode(request), "Invalid opcode")
    assert(opcode == 'RRQ' or opcode == 'WRQ', "Not an xRQ")
    assert(request:sub(#request) == '\0', "Invalid request: expected ASCII NUL terminated request")

    local cstrings = split_cstrings(request, 3)
    assert(#cstrings >= 2, "Invalid request: missing filename or mode")
    local filename = assert(is_netascii(cstrings[1]) and cstrings[1], "Requested filename must be netascii")
    -- RFC1350: the mode may be "in any combination of upper and lower case".
    local mode = assert(({netascii='netascii', octet='octet'})[cstrings[2]:lower()], "Unsupported transfer mode")

    local options
    if #cstrings > 2 then
        options = {}
        assert(#cstrings % 2 == 0, "Invalid request: malformed option list")
        for i = 3, #cstrings, 2 do
            --[[ RFC1782, and 3247 require case insensitive comparisons.
            We normalize them to lowercase with the consequence that
            duplicate keys are replaced which are forbidden by the standard
            anyway.
            ]]
            options[cstrings[i]:lower()] = cstrings[i+1]:lower()
        end
    end
    return {opcode=opcode, filename=filename, mode=mode, options=options}
end

--[[
    ACK functions
              2 bytes     2 bytes
             ---------------------
            | Opcode |   Block #  |
             ---------------------
             Figure 5-3: ACK packet
]]
function tftp.parse_ACK(ack)
    --get the sequence number from an ACK or raise a error if not valid
    assert(type(ack) == 'string' and #ack == ACKSIZE, "invalid ack")
    assert(parse_opcode(ack) == 'ACK', "invalid ack")
    return decode_u16(ack, 3)
end

--[[
  +-------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
  |  opc  |  opt1  | 0 | value1 | 0 |  optN  | 0 | valueN | 0 |
  +-------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
]]
function tftp.OACK(options)
    local stropts = {}
    for k, v in pairs(options) do
        assert(is_netascii(k))
        stropts[#stropts+1] = k .. '\0' .. tostring(v) .. '\0'
    end
    return create_opcode(OP_OACK) .. table.concat(stropts, '')
end

function tftp.ACK(tid)
    return table.concat({create_opcode(OP_ACK), create_opcode(tid)}, '')
end

--[[
    DATA functions
               2 bytes     2 bytes      n bytes
               ----------------------------------
              | Opcode |   Block #  |   Data     |
               ----------------------------------
                    Figure 5-2: DATA packet
]]
function tftp.DATA(data, tid)
    -- A nil ``data`` produces an empty (terminating) block.
    return table.concat({create_opcode(OP_DATA), create_opcode(tid), data or ''}, '')
end

function tftp.parse_DATA(data, blksize)
    --[[Parse a DATA packet into {tid=block number, data=payload}.
    ``blksize`` (default 512) is the negotiated maximum payload size.]]
    local max_payload = blksize or BLKSIZE
    assert(#data >= 4, "tftp data packets must have a 4 byte header")
    assert(#data <= 4 + max_payload, "tftp data packet payload exceeds the block size")
    assert(parse_opcode(data) == 'DATA', "Invalid opcode")
    return {tid=decode_u16(data, 3), data=data:sub(5)}
end

--[[
    ERROR Functions
               2 bytes     2 bytes      string    1 byte
              -----------------------------------------
              | Opcode |  ErrorCode |   ErrMsg   |   0  |
              -----------------------------------------
                      Figure 5-4: ERROR packet
]]
function tftp.ERROR(err)
    --[[Build an ERROR packet. A string ``err`` is sent with code 0 ("Not
    defined") and itself as the message. A number selects the matching
    RFC1350 error code and message; unknown codes get "Not defined".]]
    local defined_errors = {
        --https://tools.ietf.org/html/rfc1350#page-10
        [0] = type(err) == 'string' and err or "Not defined",
        "File not found.",
        "Access violation.",
        "Disk full or allocation exceeded.",
        "Illegal TFTP operation.",
        "Unknown transfer ID.",
        "File already exists.",
        "No such user.",
    }
    local errno = type(err) == 'string' and 0 or err
    return table.concat({
        create_opcode(OP_ERR),
        create_opcode(errno),
        defined_errors[errno] or "Not defined",
        '\0'
    }, '')
end

function tftp.parse_ERROR(err)
    -- Parse an ERROR packet into {errcode=number, errmsg=string}.
    assert(parse_opcode(err) == 'ERROR', "Invalid opcode")
    local msg_end = err:sub(-1) == '\0' and #err - 1 or #err
    return {errcode=decode_u16(err, 3), errmsg=err:sub(5, msg_end)}
end

return tftp