summaryrefslogtreecommitdiff
path: root/src/simd/vt.cpp
blob: d8133abf76f9a073c70d183cf48d122c0702dffb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
// Generates code for every target that this compiler can support.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "simd/vt.cpp"  // this file
#include <hwy/foreach_target.h>           // must come before highway.h
#include <hwy/highway.h>

#include <simdutf.h>
#include <utf8.h>
#include <vector>

#include <simd/index_of.h>
#include <simd/vt.h>

HWY_BEFORE_NAMESPACE();
namespace ghostty {
namespace HWY_NAMESPACE {

namespace hn = hwy::HWY_NAMESPACE;

using T = uint8_t;

// Decode the UTF-8 text in input into output. Returns the number of decoded
// characters. This function assumes output is large enough.
//
// This function handles malformed UTF-8 sequences by inserting a
// replacement character (U+FFFD) and continuing to decode. This function
// will consume the entire input no matter what.
size_t DecodeUTF8(const uint8_t* HWY_RESTRICT input,
                  size_t count,
                  char32_t* output) {
  // Its possible for our input to be empty since DecodeUTF8UntilControlSeq
  // doesn't check for this.
  if (count == 0) {
    return 0;
  }

  // Assume no errors for fast path.
  const size_t decoded = simdutf::convert_utf8_to_utf32(
      reinterpret_cast<const char*>(input), count, output);
  if (decoded > 0) {
    return decoded;
  }

  // Errors in the UTF input, take a slow path and do a decode with
  // replacement (with U+FFFD). Note that simdutf doesn't have a
  // decode with replacement API:
  // https://github.com/simdutf/simdutf/issues/147
  //
  // Because of this, we use a separate library with heap allocation
  // that is much, much slower (the allocation is slower, the algorithm
  // is slower, etc.) This is just so we have something that works.
  // I want to replace this.
  std::vector<char> replacement_result;
  utf8::replace_invalid(input, input + count,
                        std::back_inserter(replacement_result), 0xFFFD);
  return DecodeUTF8(reinterpret_cast<const uint8_t*>(replacement_result.data()),
                    replacement_result.size(), output);
}

/// Decode the UTF-8 text in input into output until an escape
/// character is found. This returns the number of bytes consumed
/// from input and writes the number of decoded characters into
/// output_count.
///
/// This may return a value less than count even with no escape
/// character if the input ends with an incomplete UTF-8 sequence.
/// The caller should check the next byte manually to determine
/// if it is incomplete.
template <class D>
size_t DecodeUTF8UntilControlSeqImpl(D d,
                                     const T* HWY_RESTRICT input,
                                     size_t count,
                                     char32_t* output,
                                     size_t* output_count) {
  const size_t N = hn::Lanes(d);

  // Create a vector containing ESC since that denotes a control sequence.
  const hn::Vec<D> esc_vec = Set(d, 0x1B);

  // Compare N elements at a time.
  size_t i = 0;
  for (; i + N <= count; i += N) {
    // Load the N elements from our input into a vector.
    const hn::Vec<D> input_vec = hn::LoadU(d, input + i);

    // If we don't have any escapes we keep going. We want to accumulate
    // the largest possible valid UTF-8 sequence before decoding.
    // TODO(mitchellh): benchmark this vs decoding every time
    const auto esc_idx = IndexOfChunk(d, esc_vec, input_vec);
    if (!esc_idx) {
      continue;
    }

    // We have an ESC char, decode up to this point. We start by assuming
    // a valid UTF-8 sequence and slow-path into error handling if we find
    // an invalid sequence.
    *output_count = DecodeUTF8(input, i + esc_idx.value(), output);
    return i + esc_idx.value();
  }

  // If we have leftover input then we decode it one byte at a time (slow!)
  // using pretty much the same logic as above.
  if (i != count) {
    const hn::CappedTag<T, 1> d1;
    using D1 = decltype(d1);
    const hn::Vec<D1> esc1 = Set(d1, hn::GetLane(esc_vec));
    for (; i < count; ++i) {
      const hn::Vec<D1> input_vec = hn::LoadU(d1, input + i);
      const auto esc_idx = IndexOfChunk(d1, esc1, input_vec);
      if (!esc_idx) {
        continue;
      }

      *output_count = DecodeUTF8(input, i + esc_idx.value(), output);
      return i + esc_idx.value();
    }
  }

  // If we reached this point, its possible for our input to have an
  // incomplete sequence because we're consuming the full input. We need
  // to trim any incomplete sequences from the end of the input.
  const size_t trimmed_len =
      simdutf::trim_partial_utf8(reinterpret_cast<const char*>(input), i);
  *output_count = DecodeUTF8(input, trimmed_len, output);
  return trimmed_len;
}

size_t DecodeUTF8UntilControlSeq(const uint8_t* HWY_RESTRICT input,
                                 size_t count,
                                 char32_t* output,
                                 size_t* output_count) {
  const hn::ScalableTag<uint8_t> d;
  return DecodeUTF8UntilControlSeqImpl(d, input, count, output, output_count);
}

}  // namespace HWY_NAMESPACE
}  // namespace ghostty
HWY_AFTER_NAMESPACE();

// HWY_ONCE is true for only one of the target passes
#if HWY_ONCE

namespace ghostty {

HWY_EXPORT(DecodeUTF8UntilControlSeq);

size_t DecodeUTF8UntilControlSeq(const uint8_t* HWY_RESTRICT input,
                                 size_t count,
                                 char32_t* output,
                                 size_t* output_count) {
  return HWY_DYNAMIC_DISPATCH(DecodeUTF8UntilControlSeq)(input, count, output,
                                                         output_count);
}

}  // namespace ghostty

extern "C" {

size_t ghostty_simd_decode_utf8_until_control_seq(const uint8_t* HWY_RESTRICT
                                                      input,
                                                  size_t count,
                                                  char32_t* output,
                                                  size_t* output_count) {
  return ghostty::DecodeUTF8UntilControlSeq(input, count, output, output_count);
}

}  // extern "C"

#endif  // HWY_ONCE