summaryrefslogtreecommitdiff
path: root/src/simd/base64.zig
blob: 88b97bb039225ec90254cb478768b5d25437dc0f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
const std = @import("std");
const options = @import("build_options");
const assert = std.debug.assert;
const scalar_decoder = @import("base64_scalar.zig").scalar_decoder;

const log = std.log.scoped(.simd_base64);

/// Returns the size of the buffer to allocate before calling `decode` on
/// `input`.
///
/// SIMD builds ask the C++ implementation; other builds fall back to
/// `maxLenScalar`, which yields 0 (after logging a warning) when the size
/// cannot be computed.
pub fn maxLen(input: []const u8) usize {
    return if (comptime options.simd)
        ghostty_simd_base64_max_length(input.ptr, input.len)
    else
        maxLenScalar(input);
}

/// Scalar fallback for `maxLen`: strips trailing `=` padding, then asks the
/// scalar decoder how many bytes the remaining payload decodes to. Logs a
/// warning and returns 0 if that size cannot be computed.
fn maxLenScalar(input: []const u8) usize {
    const trimmed = scalarInput(input);
    if (scalar_decoder.calcSizeForSlice(trimmed)) |size| {
        return size;
    } else |err| {
        log.warn("failed to calculate base64 size for payload: {}", .{err});
        return 0;
    }
}

/// Decodes base64 `input` into `output` and returns the prefix of `output`
/// that was written.
///
/// `output` is expected to be sized with `maxLen(input)`; it is not
/// length-checked on the SIMD path. A negative result from the C++ decoder
/// is reported as `error.Base64Invalid`; non-SIMD builds delegate to
/// `decodeScalar`.
pub fn decode(input: []const u8, output: []u8) error{Base64Invalid}![]const u8 {
    if (comptime !options.simd) {
        return decodeScalar(input, output);
    } else {
        const written = ghostty_simd_base64_decode(
            input.ptr,
            input.len,
            output.ptr,
        );
        if (written < 0) return error.Base64Invalid;
        const count: usize = @intCast(written);
        return output[0..count];
    }
}

/// Decodes `input_raw` into `output` using the portable scalar decoder.
///
/// Trailing `=` padding is stripped first (via `scalarInput`) so the result
/// matches the SIMD path. `output` must be at least as long as the decoded
/// size, which is what `maxLen` returns for the same input in non-SIMD
/// builds.
///
/// Returns `error.Base64Invalid` for malformed input. This includes input
/// whose decoded size cannot be computed, which previously decoded silently
/// to an empty slice.
fn decodeScalar(
    input_raw: []const u8,
    output: []u8,
) error{Base64Invalid}![]const u8 {
    // Trim once and reuse; the trimmed slice is what the decoder expects.
    const input = scalarInput(input_raw);

    // Ask the decoder directly rather than going through `maxLenScalar`:
    // that helper maps a size failure to 0, which would make malformed
    // input indistinguishable from a genuinely empty payload.
    const size = scalar_decoder.calcSizeForSlice(input) catch
        return error.Base64Invalid;
    if (size == 0) return "";

    assert(output.len >= size);
    scalar_decoder.decode(output, input) catch return error.Base64Invalid;
    return output[0..size];
}

/// For non-SIMD enabled builds, we trim the padding from the end of the
/// base64 input in order to get identical output with the SIMD version.
///
/// Returns `input` with every trailing `=` removed. Empty input, or input
/// consisting solely of `=`, yields an empty slice.
fn scalarInput(input: []const u8) []const u8 {
    var end: usize = input.len;
    // Check the bound before indexing. The previous `input.len - i - 1`
    // form underflowed (panic in safe builds, UB in ReleaseFast) on empty
    // or all-padding input.
    while (end > 0 and input[end - 1] == '=') end -= 1;
    return input[0..end];
}

// base64.cpp
//
// C ABI entry points into the SIMD base64 implementation. The Zig wrappers
// above only call these when `options.simd` is enabled at comptime.

/// Returns the output size to reserve for decoding the `len` bytes at
/// `input`; `maxLen` returns this value directly in SIMD builds.
/// NOTE(review): presumably an upper bound rather than an exact size —
/// confirm against base64.cpp.
extern "c" fn ghostty_simd_base64_max_length(
    input: [*]const u8,
    len: usize,
) usize;
/// Decodes the `len` bytes at `input` into `output` and returns the number
/// of bytes written. The `decode` wrapper treats a negative return value as
/// invalid input. No output length is passed, so `output` must already be
/// large enough (the wrapper's callers size it with `maxLen`).
extern "c" fn ghostty_simd_base64_decode(
    input: [*]const u8,
    len: usize,
    output: [*]u8,
) isize;

test "base64 maxLen" {
    // "hello world" (11 bytes) encoded with a single trailing pad character.
    try std.testing.expectEqual(11, maxLen("aGVsbG8gd29ybGQ="));
}

test "base64 decode" {
    const gpa = std.testing.allocator;
    const encoded = "aGVsbG8gd29ybGQ=";

    // Size the destination the way callers are expected to: via maxLen.
    const buf = try gpa.alloc(u8, maxLen(encoded));
    defer gpa.free(buf);

    const decoded = try decode(encoded, buf);
    try std.testing.expectEqualStrings("hello world", decoded);
}