Jump to content

Module:Punycode

Documentation for this module may be created at Module:Punycode/doc

----------------------------------------------------------------
-- Module:Punycode  –  RFC 3492 implementation for Lua 5.1 (Scribunto)
-- * punycode.encode(label)      – Punycode for ONE label (no dots)
-- * punycode.decode(label)      – back to Unicode
-- * punycode.toASCII(domain)    – full domain → IDNA (handles dots, xn--)
-- * punycode.toUnicode(domain)  – IDNA domain → Unicode
----------------------------------------------------------------
local punycode = {}      -- public API table, returned at the end of the module

----------------------------------------------------------------
-- Memoisation tables (persist only for a single page render)
----------------------------------------------------------------
local encodeCache = {}   -- lower-cased label -> encoded string
local decodeCache = {}   -- encoded label     -> decoded string

----------------------------------------------------------------
-- Bootstring parameter values for Punycode (RFC 3492)
----------------------------------------------------------------
local base         = 36
local tmin         = 1
local tmax         = 26
local skew         = 38
local damp         = 700
local initial_bias = 72
local initial_n    = 128     -- 0x80
local delimiter    = '-'     -- ASCII hyphen

----------------------------------------------------------------
-- UTF-8 helpers (mw.ustring exists in Scribunto; falls back otherwise)
----------------------------------------------------------------
local us = mw and mw.ustring
-- Convert a UTF-8 string into a sequence of Unicode code points.
--   s       : UTF-8 encoded string
--   returns : array of integer code points ({} for the empty string)
-- When mw.ustring is available it does the decoding. Otherwise a plain
-- Lua 5.1 decoder is used, which raises an error on a stray continuation
-- byte, an impossible lead byte, or a truncated multi-byte sequence instead
-- of silently producing garbage or failing on nil arithmetic.
local function toCodePoints(s)
    if s == "" then return {} end
    if us then
        local cps, i = {}, 1
        for ch in us.gmatch(s, ".") do
            cps[i] = us.codepoint(ch)
            i = i + 1
        end
        return cps
    end

    -- plain Lua 5.1 fallback
    local cps, i, len = {}, 1, #s
    while i <= len do
        local b1 = s:byte(i)
        local extra, cp          -- continuation-byte count, value so far
        if b1 < 0x80 then
            extra, cp = 0, b1
        elseif b1 >= 0xC0 and b1 < 0xE0 then
            extra, cp = 1, b1 - 0xC0
        elseif b1 >= 0xE0 and b1 < 0xF0 then
            extra, cp = 2, b1 - 0xE0
        elseif b1 >= 0xF0 and b1 < 0xF8 then
            extra, cp = 3, b1 - 0xF0
        else
            -- 0x80-0xBF is a continuation byte, 0xF8+ never starts a sequence
            error("punycode: invalid UTF-8 lead byte at position " .. i)
        end
        for j = i + 1, i + extra do
            local b = s:byte(j)   -- nil when the string ends mid-sequence
            if not b or b < 0x80 or b > 0xBF then
                error("punycode: malformed UTF-8 sequence at position " .. i)
            end
            cp = cp * 0x40 + (b - 0x80)
        end
        cps[#cps + 1] = cp
        i = i + extra + 1
    end
    return cps
end

-- Build a UTF-8 string from a sequence of Unicode code points.
--   cps     : array of integer code points
--   returns : the concatenated UTF-8 string
-- mw.ustring.char does the work when available; otherwise each code point
-- is encoded by hand into 1-4 bytes.
local function fromCodePoints(cps)
    local pieces = {}

    if us then
        for idx = 1, #cps do
            pieces[idx] = us.char(cps[idx])
        end
        return table.concat(pieces)
    end

    local floor, char = math.floor, string.char
    for idx = 1, #cps do
        local cp = cps[idx]
        local encoded
        if cp < 0x80 then
            -- single byte (ASCII)
            encoded = char(cp)
        elseif cp < 0x800 then
            -- two bytes: 110xxxxx 10xxxxxx
            encoded = char(
                0xC0 + floor(cp / 0x40),
                0x80 + (cp % 0x40)
            )
        elseif cp < 0x10000 then
            -- three bytes: 1110xxxx 10xxxxxx 10xxxxxx
            encoded = char(
                0xE0 + floor(cp / 0x1000),
                0x80 + (floor(cp / 0x40) % 0x40),
                0x80 + (cp % 0x40)
            )
        else
            -- four bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
            encoded = char(
                0xF0 + floor(cp / 0x40000),
                0x80 + (floor(cp / 0x1000) % 0x40),
                0x80 + (floor(cp / 0x40)   % 0x40),
                0x80 + (cp % 0x40)
            )
        end
        pieces[idx] = encoded
    end
    return table.concat(pieces)
end

----------------------------------------------------------------
-- RFC 3492 helpers
----------------------------------------------------------------
-- Map a Punycode digit value to its basic code point:
-- 0-25 become 'a'-'z', 26-35 become '0'-'9'.
local function digitToBasic(d)
    if d < 26 then
        return string.char(d + 97)      -- 97 is 'a'
    end
    return string.char(d + 22)          -- d - 26 + 48, 48 is '0'
end

-- Map the byte value of a basic code point to its Punycode digit value.
-- Letters are case-insensitive (0-25); '0'-'9' give 26-35.
-- Any other byte yields `base`, which is outside the valid digit range.
local function basicToDigit(byte)
    if byte >= 97 and byte <= 122 then          -- 'a'-'z'
        return byte - 97
    elseif byte >= 65 and byte <= 90 then       -- 'A'-'Z'
        return byte - 65
    elseif byte >= 48 and byte <= 57 then       -- '0'-'9' → 26-35
        return byte - 22
    end
    return base                                 -- invalid
end

-- Bias adaptation step of RFC 3492.
--   delta     : the delta that was just encoded/decoded
--   numpoints : number of code points handled so far, including this one
--   first     : true for the very first delta (scaled by `damp` instead of 2)
--   returns   : the new bias
local function adapt(delta, numpoints, first)
    local floor = math.floor

    if first then
        delta = floor(delta / damp)
    else
        delta = floor(delta / 2)
    end
    delta = delta + floor(delta / numpoints)

    local span  = base - tmin
    local limit = (span * tmax) / 2
    local k = 0
    while delta > limit do
        delta = floor(delta / span)
        k = k + base
    end
    return k + floor(((span + 1) * delta) / (delta + skew))
end

----------------------------------------------------------------
-- Single-label Punycode encode / decode
----------------------------------------------------------------
-- True when every byte of str is in the 7-bit ASCII range (0-127).
-- The empty string counts as ASCII.
local function isASCII(str)
    local pos, last = 1, #str
    while pos <= last do
        if str:byte(pos) >= 128 then
            return false
        end
        pos = pos + 1
    end
    return true
end

-- Encode ONE label (no dots) to Punycode, following RFC 3492 section 6.3.
--   label   : UTF-8 string. One trailing dot is stripped and the label is
--             lower-cased before encoding.
--   returns : the encoded string WITHOUT the "xn--" prefix; "" for nil or "".
--   raises  : an error if the label still contains a dot after stripping.
-- Results are memoised in encodeCache, keyed by the lower-cased label.
-- As the RFC requires, the delimiter is emitted whenever the label has at
-- least one basic (ASCII) code point, so an all-ASCII label "abc" encodes to
-- "abc-" and decodes back to "abc".
function punycode.encode(label)
    if not label or label == "" then return "" end
    label = label:gsub("%.$", "")              -- strip *trailing* dot
    if label:find("%.") then
        error("punycode.encode: one label at a time (no dots)")
    end
    label = (us and us.lower or string.lower)(label)
    if encodeCache[label] then return encodeCache[label] end

    local cp_arr = toCodePoints(label)
    local total  = #cp_arr
    local out, n, delta, bias = {}, initial_n, 0, initial_bias
    local basic = 0

    -- copy ASCII code points
    for _, cp in ipairs(cp_arr) do
        if cp < 128 then
            out[#out + 1] = string.char(cp)
            basic = basic + 1
        end
    end
    -- RFC 3492 6.3: "if b > 0 then output the delimiter" - also when the
    -- label is entirely basic, otherwise the result cannot be decoded.
    if basic > 0 then out[#out + 1] = delimiter end

    local h = basic                            -- code points handled so far
    while h < total do
        -- m = smallest code point not yet handled (>= n)
        local m = 0x7FFFFFFF
        for _, cp in ipairs(cp_arr) do
            if cp >= n and cp < m then m = cp end
        end
        delta = delta + (m - n) * (h + 1)
        n = m
        for _, cp in ipairs(cp_arr) do
            if cp < n then
                delta = delta + 1
            elseif cp == n then
                -- emit delta as a generalized variable-length integer
                local q, k = delta, base
                while true do
                    local t
                    if     k <= bias         then t = tmin
                    elseif k >= bias + tmax  then t = tmax
                    else                         t = k - bias end
                    if q < t then break end
                    out[#out + 1] = digitToBasic(t + (q - t) % (base - t))
                    q = math.floor((q - t) / (base - t))
                    k = k + base
                end
                out[#out + 1] = digitToBasic(q)
                bias  = adapt(delta, h + 1, h == basic)
                delta = 0
                h     = h + 1
            end
        end
        delta = delta + 1
        n     = n + 1
    end

    local res = table.concat(out)
    encodeCache[label] = res
    return res
end

-- Decode ONE Punycode label (without the "xn--" prefix) back to Unicode,
-- following RFC 3492 section 6.2.
--   label   : the encoded ASCII string
--   returns : the decoded UTF-8 string; "" for nil or "".
--   raises  : an error when the input is not valid Punycode (non-ASCII byte
--             in the basic part, invalid digit, input ending in the middle
--             of a number, or a resulting code point above U+10FFFF).
-- Results are memoised in decodeCache, keyed by the label as given.
function punycode.decode(label)
    if not label or label == "" then return "" end
    if decodeCache[label] then return decodeCache[label] end

    local len = #label

    -- The basic code points are everything before the LAST delimiter; the
    -- basic part itself may contain hyphens (e.g. "mnchen-ost-9db").
    local d = 0
    for idx = len, 1, -1 do
        if label:sub(idx, idx) == delimiter then
            d = idx
            break
        end
    end

    local cps = {}
    for i = 1, d - 1 do
        local b = label:byte(i)
        if b >= 0x80 then
            error("punycode.decode: non-ASCII byte in basic part")
        end
        cps[#cps + 1] = b
    end

    local n, i_val, bias = initial_n, 0, initial_bias
    local pos = (d > 0) and (d + 1) or 1

    while pos <= len do
        -- read one generalized variable-length integer into i_val
        local oldi, w, k = i_val, 1, base
        while true do
            local byte = label:byte(pos)
            if not byte then
                error("punycode.decode: truncated input")
            end
            local digit = basicToDigit(byte)
            if digit >= base then
                error("punycode.decode: invalid digit in input")
            end
            pos = pos + 1
            i_val = i_val + digit * w
            local t
            if     k <= bias         then t = tmin
            elseif k >= bias + tmax  then t = tmax
            else                         t = k - bias end
            if digit < t then break end
            w = w * (base - t)
            k = k + base
        end

        local count = #cps + 1
        bias = adapt(i_val - oldi, count, oldi == 0)
        n    = n + math.floor(i_val / count)
        if n > 0x10FFFF then
            error("punycode.decode: code point out of range")
        end
        i_val = i_val % count
        table.insert(cps, i_val + 1, n)        -- insert n at position i_val
        i_val = i_val + 1
    end

    local res = fromCodePoints(cps)
    decodeCache[label] = res
    return res
end

----------------------------------------------------------------
-- Domain-level helpers (split a dotted name and handle each label on its own)
----------------------------------------------------------------
-- Split a domain name on '.' into an array of its non-empty labels.
-- Empty labels (leading, trailing or doubled dots) are not returned.
local function splitLabels(domain)
    local parts = {}
    domain:gsub("([^%.]+)", function(piece)
        parts[#parts + 1] = piece
    end)
    return parts
end

-- Remove one trailing '.' from s, if present.
-- Returns two values: the (possibly shortened) string, and a boolean
-- telling whether a trailing dot was found.
local function stripTrailingDot(s)
    local hadDot = s:sub(-1) == '.'
    if hadDot then
        return s:sub(1, -2), true
    end
    return s, false
end

-- Unicode → ASCII/IDNA (strips dot *before* encoding, encodes each label separately)
--   domain  : UTF-8 domain name
--   returns : the domain with every non-ASCII label replaced by
--             "xn--" .. punycode.encode(label); "" for nil or "".
-- ASCII labels are passed through unchanged, and a single trailing dot on
-- the input is kept on the output. The IDNA label separators U+3002, U+FF0E
-- and U+FF61 are treated as dots (RFC 3490 section 3.1) and come out as '.'.
function punycode.toASCII(domain)
    if not domain or domain == "" then return "" end

    -- Map the alternative IDNA dots to '.', matching on their UTF-8 bytes.
    -- Each gsub also returns a count; assigning to one variable drops it.
    domain = domain:gsub("\227\128\130", ".")   -- U+3002 ideographic full stop
    domain = domain:gsub("\239\188\142", ".")   -- U+FF0E fullwidth full stop
    domain = domain:gsub("\239\189\161", ".")   -- U+FF61 halfwidth ideographic full stop

    local trailing
    domain, trailing = stripTrailingDot(domain)

    local ascii = {}
    for _, lbl in ipairs(splitLabels(domain)) do
        if isASCII(lbl) then
            ascii[#ascii + 1] = lbl
        else
            ascii[#ascii + 1] = "xn--" .. punycode.encode(lbl)
        end
    end
    local res = table.concat(ascii, ".")
    return trailing and (res .. ".") or res
end

-- ASCII/IDNA → Unicode (each label separately)
--   domain  : domain name whose IDN labels carry the "xn--" prefix
--             (the prefix is matched case-insensitively)
--   returns : the domain with every decodable "xn--" label replaced by its
--             Unicode form; "" for nil or "".
-- Other labels are passed through unchanged, and a single trailing dot on
-- the input is kept on the output. This function does not fail on a bad
-- label (RFC 3490 section 4.2): if punycode.decode raises an error or
-- yields an empty string, the original label is kept as it was.
function punycode.toUnicode(domain)
    if not domain or domain == "" then return "" end
    local trailing
    domain, trailing = stripTrailingDot(domain)

    local uni = {}
    for _, lbl in ipairs(splitLabels(domain)) do
        local shown = lbl
        if lbl:sub(1, 4):lower() == "xn--" then
            -- pcall so that one malformed label cannot abort the conversion
            local ok, decoded = pcall(punycode.decode, lbl:sub(5))
            if ok and decoded ~= "" then
                shown = decoded
            end
        end
        uni[#uni + 1] = shown
    end
    local res = table.concat(uni, ".")
    return trailing and (res .. ".") or res
end

-- Export the public API table (encode, decode, toASCII, toUnicode).
return punycode