Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'vendor/tree-sitter/src/wasm_store.c')
-rw-r--r--vendor/tree-sitter/src/wasm_store.c1823
1 files changed, 1823 insertions, 0 deletions
diff --git a/vendor/tree-sitter/src/wasm_store.c b/vendor/tree-sitter/src/wasm_store.c
new file mode 100644
index 00000000..7137a7fb
--- /dev/null
+++ b/vendor/tree-sitter/src/wasm_store.c
@@ -0,0 +1,1823 @@
+#include "tree_sitter/api.h"
+#include "./parser.h"
+#include <stdint.h>
+
+#ifdef TREE_SITTER_FEATURE_WASM
+
+#include <wasmtime.h>
+#include <wasm.h>
+#include <string.h>
+#include "./alloc.h"
+#include "./array.h"
+#include "./atomic.h"
+#include "./language.h"
+#include "./lexer.h"
+#include "./wasm_store.h"
+#include "./wasm/wasm-stdlib.h"
+
+#define array_len(a) (sizeof(a) / sizeof(a[0]))
+
+// The following symbols from the C and C++ standard libraries are available
+// for external scanners to use.
+const char *STDLIB_SYMBOLS[] = {
+ #include "./stdlib-symbols.txt"
+};
+
+// The contents of the `dylink.0` custom section of a wasm module,
+// as specified by the current WebAssembly dynamic linking ABI proposal.
+typedef struct {
+ uint32_t memory_size;
+ uint32_t memory_align;
+ uint32_t table_size;
+ uint32_t table_align;
+} WasmDylinkInfo;
+
+// WasmLanguageId - A pointer used to identify a language. This language id is
+// reference-counted, so that its ownership can be shared between the language
+// itself and the instances of the language that are held in wasm stores.
+typedef struct {
+ volatile uint32_t ref_count;
+ volatile uint32_t is_language_deleted;
+} WasmLanguageId;
+
+// LanguageWasmModule - Additional data associated with a wasm-backed
+// `TSLanguage`. This data is read-only and does not reference a particular
+// wasm store, so it can be shared by all users of a `TSLanguage`. A pointer to
+// this is stored on the language itself.
+typedef struct {
+ volatile uint32_t ref_count;
+ WasmLanguageId *language_id;
+ wasmtime_module_t *module;
+ const char *name;
+ char *symbol_name_buffer;
+ char *field_name_buffer;
+ WasmDylinkInfo dylink_info;
+} LanguageWasmModule;
+
+// LanguageWasmInstance - Additional data associated with an instantiation of
+// a `TSLanguage` in a particular wasm store. The wasm store holds one of
+// these structs for each language that it has instantiated.
+typedef struct {
+ WasmLanguageId *language_id;
+ wasmtime_instance_t instance;
+ int32_t external_states_address;
+ int32_t lex_main_fn_index;
+ int32_t lex_keyword_fn_index;
+ int32_t scanner_create_fn_index;
+ int32_t scanner_destroy_fn_index;
+ int32_t scanner_serialize_fn_index;
+ int32_t scanner_deserialize_fn_index;
+ int32_t scanner_scan_fn_index;
+} LanguageWasmInstance;
+
+typedef struct {
+ uint32_t reset_heap;
+ uint32_t proc_exit;
+ uint32_t abort;
+ uint32_t assert_fail;
+ uint32_t notify_memory_growth;
+ uint32_t debug_message;
+ uint32_t at_exit;
+ uint32_t args_get;
+ uint32_t args_sizes_get;
+} BuiltinFunctionIndices;
+
+// TSWasmStore - A struct that allows a given `Parser` to use wasm-backed
+// languages. This struct is mutable, and can only be used by one parser at a
+// time.
+struct TSWasmStore {
+ wasm_engine_t *engine;
+ wasmtime_store_t *store;
+ wasmtime_table_t function_table;
+ wasmtime_memory_t memory;
+ TSLexer *current_lexer;
+ LanguageWasmInstance *current_instance;
+ Array(LanguageWasmInstance) language_instances;
+ uint32_t current_memory_offset;
+ uint32_t current_function_table_offset;
+ uint32_t *stdlib_fn_indices;
+ BuiltinFunctionIndices builtin_fn_indices;
+ wasmtime_global_t stack_pointer_global;
+ wasm_globaltype_t *const_i32_type;
+ bool has_error;
+ uint32_t lexer_address;
+ uint32_t serialization_buffer_address;
+};
+
+typedef Array(char) StringData;
+
+// LanguageInWasmMemory - The memory layout of a `TSLanguage` when compiled to
+// wasm32. This is used to copy static language data out of the wasm memory.
+typedef struct {
+ uint32_t version;
+ uint32_t symbol_count;
+ uint32_t alias_count;
+ uint32_t token_count;
+ uint32_t external_token_count;
+ uint32_t state_count;
+ uint32_t large_state_count;
+ uint32_t production_id_count;
+ uint32_t field_count;
+ uint16_t max_alias_sequence_length;
+ int32_t parse_table;
+ int32_t small_parse_table;
+ int32_t small_parse_table_map;
+ int32_t parse_actions;
+ int32_t symbol_names;
+ int32_t field_names;
+ int32_t field_map_slices;
+ int32_t field_map_entries;
+ int32_t symbol_metadata;
+ int32_t public_symbol_map;
+ int32_t alias_map;
+ int32_t alias_sequences;
+ int32_t lex_modes;
+ int32_t lex_fn;
+ int32_t keyword_lex_fn;
+ TSSymbol keyword_capture_token;
+ struct {
+ int32_t states;
+ int32_t symbol_map;
+ int32_t create;
+ int32_t destroy;
+ int32_t scan;
+ int32_t serialize;
+ int32_t deserialize;
+ } external_scanner;
+ int32_t primary_state_ids;
+} LanguageInWasmMemory;
+
+// LexerInWasmMemory - The memory layout of a `TSLexer` when compiled to wasm32.
+// This is used to copy mutable lexing state in and out of the wasm memory.
+typedef struct {
+ int32_t lookahead;
+ TSSymbol result_symbol;
+ int32_t advance;
+ int32_t mark_end;
+ int32_t get_column;
+ int32_t is_at_included_range_start;
+ int32_t eof;
+} LexerInWasmMemory;
+
+static volatile uint32_t NEXT_LANGUAGE_ID;
+
+// Linear memory layout:
+// [ <-- stack | stdlib statics | lexer | serialization_buffer | language statics --> | heap --> ]
+#define MAX_MEMORY_SIZE (128 * 1024 * 1024 / MEMORY_PAGE_SIZE)
+
+/************************
+ * WasmDylinkMemoryInfo
+ ***********************/
+
+static uint8_t read_u8(const uint8_t **p, const uint8_t *end) {
+ return *(*p)++;
+}
+
+static inline uint64_t read_uleb128(const uint8_t **p, const uint8_t *end) {
+ uint64_t value = 0;
+ unsigned shift = 0;
+ do {
+ if (*p == end) return UINT64_MAX;
+ value += (uint64_t)(**p & 0x7f) << shift;
+ shift += 7;
+ } while (*((*p)++) >= 128);
+ return value;
+}
+
+static bool wasm_dylink_info__parse(
+ const uint8_t *bytes,
+ size_t length,
+ WasmDylinkInfo *info
+) {
+ const uint8_t WASM_MAGIC_NUMBER[4] = {0, 'a', 's', 'm'};
+ const uint8_t WASM_VERSION[4] = {1, 0, 0, 0};
+ const uint8_t WASM_CUSTOM_SECTION = 0x0;
+ const uint8_t WASM_DYLINK_MEM_INFO = 0x1;
+
+ const uint8_t *p = bytes;
+ const uint8_t *end = bytes + length;
+
+ if (length < 8) return false;
+ if (memcmp(p, WASM_MAGIC_NUMBER, 4) != 0) return false;
+ p += 4;
+ if (memcmp(p, WASM_VERSION, 4) != 0) return false;
+ p += 4;
+
+ while (p < end) {
+ uint8_t section_id = read_u8(&p, end);
+ uint32_t section_length = read_uleb128(&p, end);
+ const uint8_t *section_end = p + section_length;
+ if (section_end > end) return false;
+
+ if (section_id == WASM_CUSTOM_SECTION) {
+ uint32_t name_length = read_uleb128(&p, section_end);
+ const uint8_t *name_end = p + name_length;
+ if (name_end > section_end) return false;
+
+ if (name_length == 8 && memcmp(p, "dylink.0", 8) == 0) {
+ p = name_end;
+ while (p < section_end) {
+ uint8_t subsection_type = read_u8(&p, section_end);
+ uint32_t subsection_size = read_uleb128(&p, section_end);
+ const uint8_t *subsection_end = p + subsection_size;
+ if (subsection_end > section_end) return false;
+ if (subsection_type == WASM_DYLINK_MEM_INFO) {
+ info->memory_size = read_uleb128(&p, subsection_end);
+ info->memory_align = read_uleb128(&p, subsection_end);
+ info->table_size = read_uleb128(&p, subsection_end);
+ info->table_align = read_uleb128(&p, subsection_end);
+ return true;
+ }
+ p = subsection_end;
+ }
+ }
+ }
+ p = section_end;
+ }
+ return false;
+}
+
+/*******************************************
+ * Native callbacks exposed to wasm modules
+ *******************************************/
+
+ static wasm_trap_t *callback__abort(
+ void *env,
+ wasmtime_caller_t* caller,
+ wasmtime_val_raw_t *args_and_results,
+ size_t args_and_results_len
+) {
+ return wasmtime_trap_new("wasm module called abort", 24);
+}
+
+static wasm_trap_t *callback__debug_message(
+ void *env,
+ wasmtime_caller_t* caller,
+ wasmtime_val_raw_t *args_and_results,
+ size_t args_and_results_len
+) {
+ wasmtime_context_t *context = wasmtime_caller_context(caller);
+ TSWasmStore *store = env;
+ assert(args_and_results_len == 2);
+ uint32_t string_address = args_and_results[0].i32;
+ uint32_t value = args_and_results[1].i32;
+ uint8_t *memory = wasmtime_memory_data(context, &store->memory);
+ printf("DEBUG: %s %u\n", &memory[string_address], value);
+ return NULL;
+}
+
+static wasm_trap_t *callback__noop(
+ void *env,
+ wasmtime_caller_t* caller,
+ wasmtime_val_raw_t *args_and_results,
+ size_t args_and_results_len
+) {
+ return NULL;
+}
+
+static wasm_trap_t *callback__lexer_advance(
+ void *env,
+ wasmtime_caller_t* caller,
+ wasmtime_val_raw_t *args_and_results,
+ size_t args_and_results_len
+) {
+ wasmtime_context_t *context = wasmtime_caller_context(caller);
+ assert(args_and_results_len == 2);
+
+ TSWasmStore *store = env;
+ TSLexer *lexer = store->current_lexer;
+ bool skip = args_and_results[1].i32;
+ lexer->advance(lexer, skip);
+
+ uint8_t *memory = wasmtime_memory_data(context, &store->memory);
+ memcpy(&memory[store->lexer_address], &lexer->lookahead, sizeof(lexer->lookahead));
+ return NULL;
+}
+
+static wasm_trap_t *callback__lexer_mark_end(
+ void *env,
+ wasmtime_caller_t* caller,
+ wasmtime_val_raw_t *args_and_results,
+ size_t args_and_results_len
+) {
+ TSWasmStore *store = env;
+ TSLexer *lexer = store->current_lexer;
+ lexer->mark_end(lexer);
+ return NULL;
+}
+
+static wasm_trap_t *callback__lexer_get_column(
+ void *env,
+ wasmtime_caller_t* caller,
+ wasmtime_val_raw_t *args_and_results,
+ size_t args_and_results_len
+) {
+ TSWasmStore *store = env;
+ TSLexer *lexer = store->current_lexer;
+ uint32_t result = lexer->get_column(lexer);
+ args_and_results[0].i32 = result;
+ return NULL;
+}
+
+static wasm_trap_t *callback__lexer_is_at_included_range_start(
+ void *env,
+ wasmtime_caller_t* caller,
+ wasmtime_val_raw_t *args_and_results,
+ size_t args_and_results_len
+) {
+ TSWasmStore *store = env;
+ TSLexer *lexer = store->current_lexer;
+ bool result = lexer->is_at_included_range_start(lexer);
+ args_and_results[0].i32 = result;
+ return NULL;
+}
+
+static wasm_trap_t *callback__lexer_eof(
+ void *env,
+ wasmtime_caller_t* caller,
+ wasmtime_val_raw_t *args_and_results,
+ size_t args_and_results_len
+) {
+ TSWasmStore *store = env;
+ TSLexer *lexer = store->current_lexer;
+ bool result = lexer->eof(lexer);
+ args_and_results[0].i32 = result;
+ return NULL;
+}
+
+typedef struct {
+ uint32_t *storage_location;
+ wasmtime_func_unchecked_callback_t callback;
+ wasm_functype_t *type;
+} FunctionDefinition;
+
+static void *copy(const void *data, size_t size) {
+ void *result = ts_malloc(size);
+ memcpy(result, data, size);
+ return result;
+}
+
+static void *copy_unsized_static_array(
+ const uint8_t *data,
+ int32_t start_address,
+ const int32_t all_addresses[],
+ size_t address_count
+) {
+ int32_t end_address = 0;
+ for (unsigned i = 0; i < address_count; i++) {
+ if (all_addresses[i] > start_address) {
+ if (!end_address || all_addresses[i] < end_address) {
+ end_address = all_addresses[i];
+ }
+ }
+ }
+
+ if (!end_address) return NULL;
+ size_t size = end_address - start_address;
+ void *result = ts_malloc(size);
+ memcpy(result, &data[start_address], size);
+ return result;
+}
+
+static void *copy_strings(
+ const uint8_t *data,
+ int32_t array_address,
+ size_t count,
+ StringData *string_data
+) {
+ const char **result = ts_malloc(count * sizeof(char *));
+ for (unsigned i = 0; i < count; i++) {
+ int32_t address;
+ memcpy(&address, &data[array_address + i * sizeof(address)], sizeof(address));
+ if (address == 0) {
+ result[i] = (const char *)-1;
+ } else {
+ const uint8_t *string = &data[address];
+ uint32_t len = strlen((const char *)string);
+ result[i] = (const char *)(uintptr_t)string_data->size;
+ array_extend(string_data, len + 1, string);
+ }
+ }
+ for (unsigned i = 0; i < count; i++) {
+ if (result[i] == (const char *)-1) {
+ result[i] = NULL;
+ } else {
+ result[i] = string_data->contents + (uintptr_t)result[i];
+ }
+ }
+ return result;
+}
+
+static bool name_eq(const wasm_name_t *name, const char *string) {
+ return strncmp(string, name->data, name->size) == 0;
+}
+
+static inline wasm_functype_t* wasm_functype_new_4_0(
+ wasm_valtype_t* p1,
+ wasm_valtype_t* p2,
+ wasm_valtype_t* p3,
+ wasm_valtype_t* p4
+) {
+ wasm_valtype_t* ps[4] = {p1, p2, p3, p4};
+ wasm_valtype_vec_t params, results;
+ wasm_valtype_vec_new(&params, 4, ps);
+ wasm_valtype_vec_new_empty(&results);
+ return wasm_functype_new(&params, &results);
+}
+
+#define format(output, ...) \
+ do { \
+ size_t message_length = snprintf((char *)NULL, 0, __VA_ARGS__); \
+ *output = ts_malloc(message_length + 1); \
+ snprintf(*output, message_length + 1, __VA_ARGS__); \
+ } while (0)
+
+WasmLanguageId *language_id_new() {
+ WasmLanguageId *self = ts_malloc(sizeof(WasmLanguageId));
+ self->is_language_deleted = false;
+ self->ref_count = 1;
+ return self;
+}
+
+WasmLanguageId *language_id_clone(WasmLanguageId *self) {
+ atomic_inc(&self->ref_count);
+ return self;
+}
+
+void language_id_delete(WasmLanguageId *self) {
+ if (atomic_dec(&self->ref_count) == 0) {
+ ts_free(self);
+ }
+}
+
+static wasmtime_extern_t get_builtin_extern(
+ wasmtime_table_t *table,
+ unsigned index
+) {
+ return (wasmtime_extern_t) {
+ .kind = WASMTIME_EXTERN_FUNC,
+ .of.func = (wasmtime_func_t) {
+ .store_id = table->store_id,
+ .index = index
+ }
+ };
+}
+
+static bool ts_wasm_store__provide_builtin_import(
+ TSWasmStore *self,
+ const wasm_name_t *import_name,
+ wasmtime_extern_t *import
+) {
+ wasmtime_error_t *error = NULL;
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+
+ // Dynamic linking parameters
+ if (name_eq(import_name, "__memory_base")) {
+ wasmtime_val_t value = WASM_I32_VAL(self->current_memory_offset);
+ wasmtime_global_t global;
+ error = wasmtime_global_new(context, self->const_i32_type, &value, &global);
+ assert(!error);
+ *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = global};
+ } else if (name_eq(import_name, "__table_base")) {
+ wasmtime_val_t value = WASM_I32_VAL(self->current_function_table_offset);
+ wasmtime_global_t global;
+ error = wasmtime_global_new(context, self->const_i32_type, &value, &global);
+ assert(!error);
+ *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = global};
+ } else if (name_eq(import_name, "__stack_pointer")) {
+ *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = self->stack_pointer_global};
+ } else if (name_eq(import_name, "__indirect_function_table")) {
+ *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_TABLE, .of.table = self->function_table};
+ } else if (name_eq(import_name, "memory")) {
+ *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_MEMORY, .of.memory = self->memory};
+ }
+
+ // Builtin functions
+ else if (name_eq(import_name, "__assert_fail")) {
+ *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.assert_fail);
+ } else if (name_eq(import_name, "__cxa_atexit")) {
+ *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.at_exit);
+ } else if (name_eq(import_name, "args_get")) {
+ *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.args_get);
+ } else if (name_eq(import_name, "args_sizes_get")) {
+ *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.args_sizes_get);
+ } else if (name_eq(import_name, "abort")) {
+ *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.abort);
+ } else if (name_eq(import_name, "proc_exit")) {
+ *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.proc_exit);
+ } else if (name_eq(import_name, "emscripten_notify_memory_growth")) {
+ *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.notify_memory_growth);
+ } else if (name_eq(import_name, "tree_sitter_debug_message")) {
+ *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.debug_message);
+ } else {
+ return false;
+ }
+
+ return true;
+}
+
+static bool ts_wasm_store__call_module_initializer(
+ TSWasmStore *self,
+ const wasm_name_t *export_name,
+ wasmtime_extern_t *export,
+ wasm_trap_t **trap
+) {
+ if (
+ name_eq(export_name, "_initialize") ||
+ name_eq(export_name, "__wasm_apply_data_relocs") ||
+ name_eq(export_name, "__wasm_call_ctors")
+ ) {
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+ wasmtime_func_t initialization_func = export->of.func;
+ wasmtime_error_t *error = wasmtime_func_call(context, &initialization_func, NULL, 0, NULL, 0, trap);
+ assert(!error);
+ return true;
+ } else {
+ return false;
+ }
+}
+
+TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
+ TSWasmStore *self = ts_calloc(1, sizeof(TSWasmStore));
+ wasmtime_store_t *store = wasmtime_store_new(engine, self, NULL);
+ wasmtime_context_t *context = wasmtime_store_context(store);
+ wasmtime_error_t *error = NULL;
+ wasm_trap_t *trap = NULL;
+ wasm_message_t message = WASM_EMPTY_VEC;
+ wasm_exporttype_vec_t export_types = WASM_EMPTY_VEC;
+ wasmtime_extern_t *imports = NULL;
+ wasmtime_module_t *stdlib_module = NULL;
+ wasm_memorytype_t *memory_type = NULL;
+ wasm_tabletype_t *table_type = NULL;
+
+ // Define functions called by scanners via function pointers on the lexer.
+ LexerInWasmMemory lexer = {
+ .lookahead = 0,
+ .result_symbol = 0,
+ };
+ FunctionDefinition lexer_definitions[] = {
+ {
+ (uint32_t *)&lexer.advance,
+ callback__lexer_advance,
+ wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())
+ },
+ {
+ (uint32_t *)&lexer.mark_end,
+ callback__lexer_mark_end,
+ wasm_functype_new_1_0(wasm_valtype_new_i32())
+ },
+ {
+ (uint32_t *)&lexer.get_column,
+ callback__lexer_get_column,
+ wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())
+ },
+ {
+ (uint32_t *)&lexer.is_at_included_range_start,
+ callback__lexer_is_at_included_range_start,
+ wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())
+ },
+ {
+ (uint32_t *)&lexer.eof,
+ callback__lexer_eof,
+ wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())
+ },
+ };
+
+ // Define builtin functions that can be imported by scanners.
+ BuiltinFunctionIndices builtin_fn_indices;
+ FunctionDefinition builtin_definitions[] = {
+ {
+ &builtin_fn_indices.proc_exit,
+ callback__abort,
+ wasm_functype_new_1_0(wasm_valtype_new_i32())
+ },
+ {
+ &builtin_fn_indices.abort,
+ callback__abort,
+ wasm_functype_new_0_0()
+ },
+ {
+ &builtin_fn_indices.assert_fail,
+ callback__abort,
+ wasm_functype_new_4_0(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
+ },
+ {
+ &builtin_fn_indices.notify_memory_growth,
+ callback__noop,
+ wasm_functype_new_1_0(wasm_valtype_new_i32())
+ },
+ {
+ &builtin_fn_indices.debug_message,
+ callback__debug_message,
+ wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())
+ },
+ {
+ &builtin_fn_indices.at_exit,
+ callback__noop,
+ wasm_functype_new_3_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
+ },
+ {
+ &builtin_fn_indices.args_get,
+ callback__noop,
+ wasm_functype_new_2_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
+ },
+ {
+ &builtin_fn_indices.args_sizes_get,
+ callback__noop,
+ wasm_functype_new_2_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
+ },
+ };
+
+ // Create all of the wasm functions.
+ unsigned builtin_definitions_len = array_len(builtin_definitions);
+ unsigned lexer_definitions_len = array_len(lexer_definitions);
+ for (unsigned i = 0; i < builtin_definitions_len; i++) {
+ FunctionDefinition *definition = &builtin_definitions[i];
+ wasmtime_func_t func;
+ wasmtime_func_new_unchecked(context, definition->type, definition->callback, self, NULL, &func);
+ *definition->storage_location = func.index;
+ wasm_functype_delete(definition->type);
+ }
+ for (unsigned i = 0; i < lexer_definitions_len; i++) {
+ FunctionDefinition *definition = &lexer_definitions[i];
+ wasmtime_func_t func;
+ wasmtime_func_new_unchecked(context, definition->type, definition->callback, self, NULL, &func);
+ *definition->storage_location = func.index;
+ wasm_functype_delete(definition->type);
+ }
+
+ // Compile the stdlib module.
+ error = wasmtime_module_new(engine, STDLIB_WASM, STDLIB_WASM_LEN, &stdlib_module);
+ if (error) {
+ wasmtime_error_message(error, &message);
+ wasm_error->kind = TSWasmErrorKindCompile;
+ format(
+ &wasm_error->message,
+ "failed to compile wasm stdlib: %.*s",
+ (int)message.size, message.data
+ );
+ goto error;
+ }
+
+ // Retrieve the stdlib module's imports.
+ wasm_importtype_vec_t import_types = WASM_EMPTY_VEC;
+ wasmtime_module_imports(stdlib_module, &import_types);
+
+ // Find the initial number of memory pages needed by the stdlib.
+ const wasm_memorytype_t *stdlib_memory_type;
+ for (unsigned i = 0; i < import_types.size; i++) {
+ wasm_importtype_t *import_type = import_types.data[i];
+ const wasm_name_t *import_name = wasm_importtype_name(import_type);
+ if (name_eq(import_name, "memory")) {
+ const wasm_externtype_t *type = wasm_importtype_type(import_type);
+ stdlib_memory_type = wasm_externtype_as_memorytype_const(type);
+ }
+ }
+ if (!stdlib_memory_type) {
+ wasm_error->kind = TSWasmErrorKindCompile;
+ format(
+ &wasm_error->message,
+ "wasm stdlib is missing the 'memory' import"
+ );
+ goto error;
+ }
+
+ // Initialize store's memory
+ uint64_t initial_memory_pages = wasmtime_memorytype_minimum(stdlib_memory_type);
+ wasm_limits_t memory_limits = {.min = initial_memory_pages, .max = MAX_MEMORY_SIZE};
+ memory_type = wasm_memorytype_new(&memory_limits);
+ wasmtime_memory_t memory;
+ error = wasmtime_memory_new(context, memory_type, &memory);
+ if (error) {
+ wasmtime_error_message(error, &message);
+ wasm_error->kind = TSWasmErrorKindAllocate;
+ format(
+ &wasm_error->message,
+ "failed to allocate wasm memory: %.*s",
+ (int)message.size, message.data
+ );
+ goto error;
+ }
+ wasm_memorytype_delete(memory_type);
+ memory_type = NULL;
+
+ // Initialize store's function table
+ wasm_limits_t table_limits = {.min = 1, .max = wasm_limits_max_default};
+ table_type = wasm_tabletype_new(wasm_valtype_new(WASM_FUNCREF), &table_limits);
+ wasmtime_val_t initializer = {.kind = WASMTIME_FUNCREF};
+ wasmtime_table_t function_table;
+ error = wasmtime_table_new(context, table_type, &initializer, &function_table);
+ if (error) {
+ wasmtime_error_message(error, &message);
+ wasm_error->kind = TSWasmErrorKindAllocate;
+ format(
+ &wasm_error->message,
+ "failed to allocate wasm table: %.*s",
+ (int)message.size, message.data
+ );
+ goto error;
+ }
+ wasm_tabletype_delete(table_type);
+ table_type = NULL;
+
+ unsigned stdlib_symbols_len = array_len(STDLIB_SYMBOLS);
+
+ // Define globals for the stack and heap start addresses.
+ wasm_globaltype_t *const_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_CONST);
+ wasm_globaltype_t *var_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_VAR);
+
+ wasmtime_val_t stack_pointer_value = WASM_I32_VAL(0);
+ wasmtime_global_t stack_pointer_global;
+ error = wasmtime_global_new(context, var_i32_type, &stack_pointer_value, &stack_pointer_global);
+ assert(!error);
+
+ *self = (TSWasmStore) {
+ .engine = engine,
+ .store = store,
+ .memory = memory,
+ .function_table = function_table,
+ .language_instances = array_new(),
+ .stdlib_fn_indices = ts_calloc(stdlib_symbols_len, sizeof(uint32_t)),
+ .builtin_fn_indices = builtin_fn_indices,
+ .stack_pointer_global = stack_pointer_global,
+ .current_memory_offset = 0,
+ .current_function_table_offset = 0,
+ .const_i32_type = const_i32_type,
+ };
+
+ // Set up the imports for the stdlib module.
+ imports = ts_calloc(import_types.size, sizeof(wasmtime_extern_t));
+ for (unsigned i = 0; i < import_types.size; i++) {
+ wasm_importtype_t *type = import_types.data[i];
+ const wasm_name_t *import_name = wasm_importtype_name(type);
+ if (!ts_wasm_store__provide_builtin_import(self, import_name, &imports[i])) {
+ wasm_error->kind = TSWasmErrorKindInstantiate;
+ format(
+ &wasm_error->message,
+ "unexpected import in wasm stdlib: %.*s\n",
+ (int)import_name->size, import_name->data
+ );
+ goto error;
+ }
+ }
+
+ // Instantiate the stdlib module.
+ wasmtime_instance_t instance;
+ error = wasmtime_instance_new(context, stdlib_module, imports, import_types.size, &instance, &trap);
+ ts_free(imports);
+ imports = NULL;
+ if (error) {
+ wasmtime_error_message(error, &message);
+ wasm_error->kind = TSWasmErrorKindInstantiate;
+ format(
+ &wasm_error->message,
+ "failed to instantiate wasm stdlib module: %.*s",
+ (int)message.size, message.data
+ );
+ goto error;
+ }
+ if (trap) {
+ wasm_trap_message(trap, &message);
+ wasm_error->kind = TSWasmErrorKindInstantiate;
+ format(
+ &wasm_error->message,
+ "trapped when instantiating wasm stdlib module: %.*s",
+ (int)message.size, message.data
+ );
+ goto error;
+ }
+ wasm_importtype_vec_delete(&import_types);
+
+ // Process the stdlib module's exports.
+ for (unsigned i = 0; i < stdlib_symbols_len; i++) {
+ self->stdlib_fn_indices[i] = UINT32_MAX;
+ }
+ wasmtime_module_exports(stdlib_module, &export_types);
+ for (unsigned i = 0; i < export_types.size; i++) {
+ wasm_exporttype_t *export_type = export_types.data[i];
+ const wasm_name_t *name = wasm_exporttype_name(export_type);
+
+ char *export_name;
+ size_t name_len;
+ wasmtime_extern_t export = {.kind = WASM_EXTERN_GLOBAL};
+ bool exists = wasmtime_instance_export_nth(context, &instance, i, &export_name, &name_len, &export);
+ assert(exists);
+
+ if (export.kind == WASMTIME_EXTERN_GLOBAL) {
+ if (name_eq(name, "__stack_pointer")) {
+ self->stack_pointer_global = export.of.global;
+ }
+ }
+
+ if (export.kind == WASMTIME_EXTERN_FUNC) {
+ if (ts_wasm_store__call_module_initializer(self, name, &export, &trap)) {
+ if (trap) {
+ wasm_trap_message(trap, &message);
+ wasm_error->kind = TSWasmErrorKindInstantiate;
+ format(
+ &wasm_error->message,
+ "trap when calling stdlib relocation function: %.*s\n",
+ (int)message.size, message.data
+ );
+ goto error;
+ }
+ continue;
+ }
+
+ if (name_eq(name, "reset_heap")) {
+ self->builtin_fn_indices.reset_heap = export.of.func.index;
+ continue;
+ }
+
+ for (unsigned j = 0; j < stdlib_symbols_len; j++) {
+ if (name_eq(name, STDLIB_SYMBOLS[j])) {
+ self->stdlib_fn_indices[j] = export.of.func.index;
+ break;
+ }
+ }
+ }
+ }
+
+ if (self->builtin_fn_indices.reset_heap == UINT32_MAX) {
+ wasm_error->kind = TSWasmErrorKindInstantiate;
+ format(
+ &wasm_error->message,
+ "missing malloc reset function in wasm stdlib"
+ );
+ goto error;
+ }
+
+ for (unsigned i = 0; i < stdlib_symbols_len; i++) {
+ if (self->stdlib_fn_indices[i] == UINT32_MAX) {
+ wasm_error->kind = TSWasmErrorKindInstantiate;
+ format(
+ &wasm_error->message,
+ "missing exported symbol in wasm stdlib: %s",
+ STDLIB_SYMBOLS[i]
+ );
+ goto error;
+ }
+ }
+
+ wasm_exporttype_vec_delete(&export_types);
+ wasmtime_module_delete(stdlib_module);
+
+ // Add all of the lexer callback functions to the function table. Store their function table
+ // indices on the in-memory lexer.
+ uint32_t table_index;
+ error = wasmtime_table_grow(context, &function_table, lexer_definitions_len, &initializer, &table_index);
+ if (error) {
+ wasmtime_error_message(error, &message);
+ wasm_error->kind = TSWasmErrorKindAllocate;
+ format(
+ &wasm_error->message,
+ "failed to grow wasm table to initial size: %.*s",
+ (int)message.size, message.data
+ );
+ goto error;
+ }
+ for (unsigned i = 0; i < lexer_definitions_len; i++) {
+ FunctionDefinition *definition = &lexer_definitions[i];
+ wasmtime_func_t func = {function_table.store_id, *definition->storage_location};
+ wasmtime_val_t func_val = {.kind = WASMTIME_FUNCREF, .of.funcref = func};
+ error = wasmtime_table_set(context, &function_table, table_index, &func_val);
+ assert(!error);
+ *(int32_t *)(definition->storage_location) = table_index;
+ table_index++;
+ }
+
+ self->current_function_table_offset = table_index;
+ self->lexer_address = initial_memory_pages * MEMORY_PAGE_SIZE;
+ self->serialization_buffer_address = self->lexer_address + sizeof(LexerInWasmMemory);
+ self->current_memory_offset = self->serialization_buffer_address + TREE_SITTER_SERIALIZATION_BUFFER_SIZE;
+
+ // Grow the memory enough to hold the builtin lexer and serialization buffer.
+ uint32_t new_pages_needed = (self->current_memory_offset - self->lexer_address - 1) / MEMORY_PAGE_SIZE + 1;
+ uint64_t prev_memory_size;
+ wasmtime_memory_grow(context, &memory, new_pages_needed, &prev_memory_size);
+
+ uint8_t *memory_data = wasmtime_memory_data(context, &memory);
+ memcpy(&memory_data[self->lexer_address], &lexer, sizeof(lexer));
+ return self;
+
+error:
+ ts_free(self);
+ if (stdlib_module) wasmtime_module_delete(stdlib_module);
+ if (store) wasmtime_store_delete(store);
+ if (import_types.size) wasm_importtype_vec_delete(&import_types);
+ if (memory_type) wasm_memorytype_delete(memory_type);
+ if (table_type) wasm_tabletype_delete(table_type);
+ if (trap) wasm_trap_delete(trap);
+ if (error) wasmtime_error_delete(error);
+ if (message.size) wasm_byte_vec_delete(&message);
+ if (export_types.size) wasm_exporttype_vec_delete(&export_types);
+ if (imports) ts_free(imports);
+ return NULL;
+}
+
+void ts_wasm_store_delete(TSWasmStore *self) {
+ if (!self) return;
+ ts_free(self->stdlib_fn_indices);
+ wasm_globaltype_delete(self->const_i32_type);
+ wasmtime_store_delete(self->store);
+ wasm_engine_delete(self->engine);
+ for (unsigned i = 0; i < self->language_instances.size; i++) {
+ LanguageWasmInstance *instance = &self->language_instances.contents[i];
+ language_id_delete(instance->language_id);
+ }
+ array_delete(&self->language_instances);
+ ts_free(self);
+}
+
+size_t ts_wasm_store_language_count(const TSWasmStore *self) {
+ size_t result = 0;
+ for (unsigned i = 0; i < self->language_instances.size; i++) {
+ const WasmLanguageId *id = self->language_instances.contents[i].language_id;
+ if (!id->is_language_deleted) {
+ result++;
+ }
+ }
+ return result;
+}
+
+static bool ts_wasm_store__instantiate(
+ TSWasmStore *self,
+ wasmtime_module_t *module,
+ const char *language_name,
+ const WasmDylinkInfo *dylink_info,
+ wasmtime_instance_t *result,
+ int32_t *language_address,
+ char **error_message
+) {
+ wasmtime_error_t *error = NULL;
+ wasm_trap_t *trap = NULL;
+ wasm_message_t message = WASM_EMPTY_VEC;
+ char *language_function_name = NULL;
+ wasmtime_extern_t *imports = NULL;
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+
+ // Grow the function table to make room for the new functions.
+ wasmtime_val_t initializer = {.kind = WASMTIME_FUNCREF};
+ uint32_t prev_table_size;
+ error = wasmtime_table_grow(context, &self->function_table, dylink_info->table_size, &initializer, &prev_table_size);
+ if (error) {
+ format(error_message, "invalid function table size %u", dylink_info->table_size);
+ goto error;
+ }
+
+ // Grow the memory to make room for the new data.
+ uint32_t needed_memory_size = self->current_memory_offset + dylink_info->memory_size;
+ uint32_t current_memory_size = wasmtime_memory_data_size(context, &self->memory);
+ if (needed_memory_size > current_memory_size) {
+ uint32_t pages_to_grow = (
+ needed_memory_size - current_memory_size + MEMORY_PAGE_SIZE - 1) /
+ MEMORY_PAGE_SIZE;
+ uint64_t prev_memory_size;
+ error = wasmtime_memory_grow(context, &self->memory, pages_to_grow, &prev_memory_size);
+ if (error) {
+ format(error_message, "invalid memory size %u", dylink_info->memory_size);
+ goto error;
+ }
+ }
+
+ // Construct the language function name as string.
+ format(&language_function_name, "tree_sitter_%s", language_name);
+
+ const uint64_t store_id = self->function_table.store_id;
+
+ // Build the imports list for the module.
+ wasm_importtype_vec_t import_types = WASM_EMPTY_VEC;
+ wasmtime_module_imports(module, &import_types);
+ imports = ts_calloc(import_types.size, sizeof(wasmtime_extern_t));
+
+ for (unsigned i = 0; i < import_types.size; i++) {
+ const wasm_importtype_t *import_type = import_types.data[i];
+ const wasm_name_t *import_name = wasm_importtype_name(import_type);
+ if (import_name->size == 0) {
+ format(error_message, "empty import name");
+ goto error;
+ }
+
+ if (ts_wasm_store__provide_builtin_import(self, import_name, &imports[i])) {
+ continue;
+ }
+
+ bool defined_in_stdlib = false;
+ for (unsigned j = 0; j < array_len(STDLIB_SYMBOLS); j++) {
+ if (name_eq(import_name, STDLIB_SYMBOLS[j])) {
+ uint16_t address = self->stdlib_fn_indices[j];
+ imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_FUNC, .of.func = {store_id, address}};
+ defined_in_stdlib = true;
+ break;
+ }
+ }
+
+ if (!defined_in_stdlib) {
+ format(
+ error_message,
+ "invalid import '%.*s'\n",
+ (int)import_name->size, import_name->data
+ );
+ goto error;
+ }
+ }
+
+ wasmtime_instance_t instance;
+ error = wasmtime_instance_new(context, module, imports, import_types.size, &instance, &trap);
+ wasm_importtype_vec_delete(&import_types);
+ ts_free(imports);
+ imports = NULL;
+ if (error) {
+ wasmtime_error_message(error, &message);
+ format(
+ error_message,
+ "error instantiating wasm module: %.*s\n",
+ (int)message.size, message.data
+ );
+ goto error;
+ }
+ if (trap) {
+ wasm_trap_message(trap, &message);
+ format(
+ error_message,
+ "trap when instantiating wasm module: %.*s\n",
+ (int)message.size, message.data
+ );
+ goto error;
+ }
+
+ self->current_memory_offset += dylink_info->memory_size;
+ self->current_function_table_offset += dylink_info->table_size;
+
+ // Process the module's exports.
+ bool found_language = false;
+ wasmtime_extern_t language_extern;
+ wasm_exporttype_vec_t export_types = WASM_EMPTY_VEC;
+ wasmtime_module_exports(module, &export_types);
+ for (unsigned i = 0; i < export_types.size; i++) {
+ wasm_exporttype_t *export_type = export_types.data[i];
+ const wasm_name_t *name = wasm_exporttype_name(export_type);
+
+ size_t name_len;
+ char *export_name;
+ wasmtime_extern_t export = {.kind = WASM_EXTERN_GLOBAL};
+ bool exists = wasmtime_instance_export_nth(context, &instance, i, &export_name, &name_len, &export);
+ assert(exists);
+
+ // If the module exports an initialization or data-relocation function, call it.
+ if (ts_wasm_store__call_module_initializer(self, name, &export, &trap)) {
+ if (trap) {
+ wasm_trap_message(trap, &message);
+ format(
+ error_message,
+ "trap when calling data relocation function: %.*s\n",
+ (int)message.size, message.data
+ );
+ goto error;
+ }
+ }
+
+ // Find the main language function for the module.
+ else if (name_eq(name, language_function_name)) {
+ language_extern = export;
+ found_language = true;
+ }
+ }
+ wasm_exporttype_vec_delete(&export_types);
+
+ if (!found_language) {
+ format(
+ error_message,
+ "module did not contain language function: %s",
+ language_function_name
+ );
+ goto error;
+ }
+
+ // Invoke the language function to get the static address of the language object.
+ wasmtime_func_t language_func = language_extern.of.func;
+ wasmtime_val_t language_address_val;
+ error = wasmtime_func_call(context, &language_func, NULL, 0, &language_address_val, 1, &trap);
+ assert(!error);
+ if (trap) {
+ wasm_trap_message(trap, &message);
+ format(
+ error_message,
+ "trapped when calling language function: %s: %.*s\n",
+ language_function_name, (int)message.size, message.data
+ );
+ goto error;
+ }
+
+ if (language_address_val.kind != WASMTIME_I32) {
+ format(
+ error_message,
+ "language function did not return an integer: %s\n",
+ language_function_name
+ );
+ goto error;
+ }
+
+ ts_free(language_function_name);
+ *result = instance;
+ *language_address = language_address_val.of.i32;
+ return true;
+
+error:
+ if (language_function_name) ts_free(language_function_name);
+ if (message.size) wasm_byte_vec_delete(&message);
+ if (error) wasmtime_error_delete(error);
+ if (trap) wasm_trap_delete(trap);
+ if (imports) ts_free(imports);
+ return false;
+}
+
+static bool ts_wasm_store__sentinel_lex_fn(TSLexer *_lexer, TSStateId state) {
+ return false;
+}
+
+const TSLanguage *ts_wasm_store_load_language(
+ TSWasmStore *self,
+ const char *language_name,
+ const char *wasm,
+ uint32_t wasm_len,
+ TSWasmError *wasm_error
+) {
+ WasmDylinkInfo dylink_info;
+ wasmtime_module_t *module = NULL;
+ wasmtime_error_t *error = NULL;
+ wasm_error->kind = TSWasmErrorKindNone;
+
+ if (!wasm_dylink_info__parse((const unsigned char *)wasm, wasm_len, &dylink_info)) {
+ wasm_error->kind = TSWasmErrorKindParse;
+ format(&wasm_error->message, "failed to parse dylink section of wasm module");
+ goto error;
+ }
+
+ // Compile the wasm code.
+ error = wasmtime_module_new(self->engine, (const uint8_t *)wasm, wasm_len, &module);
+ if (error) {
+ wasm_message_t message;
+ wasmtime_error_message(error, &message);
+ wasm_error->kind = TSWasmErrorKindCompile;
+ format(&wasm_error->message, "error compiling wasm module: %.*s", (int)message.size, message.data);
+ wasm_byte_vec_delete(&message);
+ goto error;
+ }
+
+ // Instantiate the module in this store.
+ wasmtime_instance_t instance;
+ int32_t language_address;
+ if (!ts_wasm_store__instantiate(
+ self,
+ module,
+ language_name,
+ &dylink_info,
+ &instance,
+ &language_address,
+ &wasm_error->message
+ )) {
+ wasm_error->kind = TSWasmErrorKindInstantiate;
+ goto error;
+ }
+
+ // Copy all of the static data out of the language object in wasm memory,
+ // constructing a native language object.
+ LanguageInWasmMemory wasm_language;
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+ const uint8_t *memory = wasmtime_memory_data(context, &self->memory);
+ memcpy(&wasm_language, &memory[language_address], sizeof(LanguageInWasmMemory));
+
+ if (wasm_language.version < LANGUAGE_VERSION_USABLE_VIA_WASM) {
+ wasm_error->kind = TSWasmErrorKindInstantiate;
+ format(&wasm_error->message, "language version %u is too old for wasm", wasm_language.version);
+ goto error;
+ }
+
+ int32_t addresses[] = {
+ wasm_language.alias_map,
+ wasm_language.alias_sequences,
+ wasm_language.field_map_entries,
+ wasm_language.field_map_slices,
+ wasm_language.field_names,
+ wasm_language.keyword_lex_fn,
+ wasm_language.lex_fn,
+ wasm_language.lex_modes,
+ wasm_language.parse_actions,
+ wasm_language.parse_table,
+ wasm_language.primary_state_ids,
+ wasm_language.primary_state_ids,
+ wasm_language.public_symbol_map,
+ wasm_language.small_parse_table,
+ wasm_language.small_parse_table_map,
+ wasm_language.symbol_metadata,
+ wasm_language.symbol_metadata,
+ wasm_language.symbol_names,
+ wasm_language.external_token_count > 0 ? wasm_language.external_scanner.states : 0,
+ wasm_language.external_token_count > 0 ? wasm_language.external_scanner.symbol_map : 0,
+ wasm_language.external_token_count > 0 ? wasm_language.external_scanner.create : 0,
+ wasm_language.external_token_count > 0 ? wasm_language.external_scanner.destroy : 0,
+ wasm_language.external_token_count > 0 ? wasm_language.external_scanner.scan : 0,
+ wasm_language.external_token_count > 0 ? wasm_language.external_scanner.serialize : 0,
+ wasm_language.external_token_count > 0 ? wasm_language.external_scanner.deserialize : 0,
+ language_address,
+ self->current_memory_offset,
+ };
+ uint32_t address_count = array_len(addresses);
+
+ TSLanguage *language = ts_calloc(1, sizeof(TSLanguage));
+ StringData symbol_name_buffer = array_new();
+ StringData field_name_buffer = array_new();
+
+ *language = (TSLanguage) {
+ .version = wasm_language.version,
+ .symbol_count = wasm_language.symbol_count,
+ .alias_count = wasm_language.alias_count,
+ .token_count = wasm_language.token_count,
+ .external_token_count = wasm_language.external_token_count,
+ .state_count = wasm_language.state_count,
+ .large_state_count = wasm_language.large_state_count,
+ .production_id_count = wasm_language.production_id_count,
+ .field_count = wasm_language.field_count,
+ .max_alias_sequence_length = wasm_language.max_alias_sequence_length,
+ .keyword_capture_token = wasm_language.keyword_capture_token,
+ .parse_table = copy(
+ &memory[wasm_language.parse_table],
+ wasm_language.large_state_count * wasm_language.symbol_count * sizeof(uint16_t)
+ ),
+ .parse_actions = copy_unsized_static_array(
+ memory,
+ wasm_language.parse_actions,
+ addresses,
+ address_count
+ ),
+ .symbol_names = copy_strings(
+ memory,
+ wasm_language.symbol_names,
+ wasm_language.symbol_count + wasm_language.alias_count,
+ &symbol_name_buffer
+ ),
+ .symbol_metadata = copy(
+ &memory[wasm_language.symbol_metadata],
+ (wasm_language.symbol_count + wasm_language.alias_count) * sizeof(TSSymbolMetadata)
+ ),
+ .public_symbol_map = copy(
+ &memory[wasm_language.public_symbol_map],
+ (wasm_language.symbol_count + wasm_language.alias_count) * sizeof(TSSymbol)
+ ),
+ .lex_modes = copy(
+ &memory[wasm_language.lex_modes],
+ wasm_language.state_count * sizeof(TSLexMode)
+ ),
+ };
+
+ if (language->field_count > 0 && language->production_id_count > 0) {
+ language->field_map_slices = copy(
+ &memory[wasm_language.field_map_slices],
+ wasm_language.production_id_count * sizeof(TSFieldMapSlice)
+ );
+ const TSFieldMapSlice last_field_map_slice = language->field_map_slices[language->production_id_count - 1];
+ language->field_map_entries = copy(
+ &memory[wasm_language.field_map_entries],
+ (last_field_map_slice.index + last_field_map_slice.length) * sizeof(TSFieldMapEntry)
+ );
+ language->field_names = copy_strings(
+ memory,
+ wasm_language.field_names,
+ wasm_language.field_count + 1,
+ &field_name_buffer
+ );
+ }
+
+ if (language->max_alias_sequence_length > 0 && language->production_id_count > 0) {
+ // The alias map contains symbols, alias counts, and aliases, terminated by a null symbol.
+ int32_t alias_map_size = 0;
+ for (;;) {
+ TSSymbol symbol;
+ memcpy(&symbol, &memory[wasm_language.alias_map + alias_map_size], sizeof(symbol));
+ alias_map_size += sizeof(TSSymbol);
+ if (symbol == 0) break;
+ uint16_t value_count;
+ memcpy(&value_count, &memory[wasm_language.alias_map + alias_map_size], sizeof(value_count));
+ alias_map_size += value_count * sizeof(TSSymbol);
+ }
+ language->alias_map = copy(
+ &memory[wasm_language.alias_map],
+ alias_map_size * sizeof(TSSymbol)
+ );
+ language->alias_sequences = copy(
+ &memory[wasm_language.alias_sequences],
+ wasm_language.production_id_count * wasm_language.max_alias_sequence_length * sizeof(TSSymbol)
+ );
+ }
+
+ if (language->state_count > language->large_state_count) {
+ uint32_t small_state_count = wasm_language.state_count - wasm_language.large_state_count;
+ language->small_parse_table_map = copy(
+ &memory[wasm_language.small_parse_table_map],
+ small_state_count * sizeof(uint32_t)
+ );
+ language->small_parse_table = copy_unsized_static_array(
+ memory,
+ wasm_language.small_parse_table,
+ addresses,
+ address_count
+ );
+ }
+
+ if (language->version >= LANGUAGE_VERSION_WITH_PRIMARY_STATES) {
+ language->primary_state_ids = copy(
+ &memory[wasm_language.primary_state_ids],
+ wasm_language.state_count * sizeof(TSStateId)
+ );
+ }
+
+ if (language->external_token_count > 0) {
+ language->external_scanner.symbol_map = copy(
+ &memory[wasm_language.external_scanner.symbol_map],
+ wasm_language.external_token_count * sizeof(TSSymbol)
+ );
+ language->external_scanner.states = (void *)(uintptr_t)wasm_language.external_scanner.states;
+ }
+
+ unsigned name_len = strlen(language_name);
+ char *name = ts_malloc(name_len + 1);
+ memcpy(name, language_name, name_len);
+ name[name_len] = '\0';
+
+ LanguageWasmModule *language_module = ts_malloc(sizeof(LanguageWasmModule));
+ *language_module = (LanguageWasmModule) {
+ .language_id = language_id_new(),
+ .module = module,
+ .name = name,
+ .symbol_name_buffer = symbol_name_buffer.contents,
+ .field_name_buffer = field_name_buffer.contents,
+ .dylink_info = dylink_info,
+ .ref_count = 1,
+ };
+
+ // The lex functions are not used for wasm languages. Use those two fields
+ // to mark this language as WASM-based and to store the language's
+ // WASM-specific data.
+ language->lex_fn = ts_wasm_store__sentinel_lex_fn;
+ language->keyword_lex_fn = (void *)language_module;
+
+ // Clear out any instances of languages that have been deleted.
+ for (unsigned i = 0; i < self->language_instances.size; i++) {
+ WasmLanguageId *id = self->language_instances.contents[i].language_id;
+ if (id->is_language_deleted) {
+ language_id_delete(id);
+ array_erase(&self->language_instances, i);
+ i--;
+ }
+ }
+
+ // Store this store's instance of this language module.
+ array_push(&self->language_instances, ((LanguageWasmInstance) {
+ .language_id = language_id_clone(language_module->language_id),
+ .instance = instance,
+ .external_states_address = wasm_language.external_scanner.states,
+ .lex_main_fn_index = wasm_language.lex_fn,
+ .lex_keyword_fn_index = wasm_language.keyword_lex_fn,
+ .scanner_create_fn_index = wasm_language.external_scanner.create,
+ .scanner_destroy_fn_index = wasm_language.external_scanner.destroy,
+ .scanner_serialize_fn_index = wasm_language.external_scanner.serialize,
+ .scanner_deserialize_fn_index = wasm_language.external_scanner.deserialize,
+ .scanner_scan_fn_index = wasm_language.external_scanner.scan,
+ }));
+
+ return language;
+
+error:
+ if (module) wasmtime_module_delete(module);
+ return NULL;
+}
+
+bool ts_wasm_store_add_language(
+ TSWasmStore *self,
+ const TSLanguage *language,
+ uint32_t *index
+) {
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+ const LanguageWasmModule *language_module = (void *)language->keyword_lex_fn;
+
+ // Search for this store's instance of the language module. Also clear out any
+ // instances of languages that have been deleted.
+ bool exists = false;
+ for (unsigned i = 0; i < self->language_instances.size; i++) {
+ WasmLanguageId *id = self->language_instances.contents[i].language_id;
+ if (id->is_language_deleted) {
+ language_id_delete(id);
+ array_erase(&self->language_instances, i);
+ i--;
+ } else if (id == language_module->language_id) {
+ exists = true;
+ *index = i;
+ }
+ }
+
+ // If the language module has not been instantiated in this store, then add
+ // it to this store.
+ if (!exists) {
+ *index = self->language_instances.size;
+ char *message;
+ wasmtime_instance_t instance;
+ int32_t language_address;
+ if (!ts_wasm_store__instantiate(
+ self,
+ language_module->module,
+ language_module->name,
+ &language_module->dylink_info,
+ &instance,
+ &language_address,
+ &message
+ )) {
+ ts_free(message);
+ return false;
+ }
+
+ LanguageInWasmMemory wasm_language;
+ const uint8_t *memory = wasmtime_memory_data(context, &self->memory);
+ memcpy(&wasm_language, &memory[language_address], sizeof(LanguageInWasmMemory));
+ array_push(&self->language_instances, ((LanguageWasmInstance) {
+ .language_id = language_id_clone(language_module->language_id),
+ .instance = instance,
+ .external_states_address = wasm_language.external_scanner.states,
+ .lex_main_fn_index = wasm_language.lex_fn,
+ .lex_keyword_fn_index = wasm_language.keyword_lex_fn,
+ .scanner_create_fn_index = wasm_language.external_scanner.create,
+ .scanner_destroy_fn_index = wasm_language.external_scanner.destroy,
+ .scanner_serialize_fn_index = wasm_language.external_scanner.serialize,
+ .scanner_deserialize_fn_index = wasm_language.external_scanner.deserialize,
+ .scanner_scan_fn_index = wasm_language.external_scanner.scan,
+ }));
+ }
+
+ return true;
+}
+
+void ts_wasm_store_reset_heap(TSWasmStore *self) {
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+ wasmtime_func_t func = {
+ self->function_table.store_id,
+ self->builtin_fn_indices.reset_heap
+ };
+ wasm_trap_t *trap = NULL;
+ wasmtime_val_t args[1] = {
+ {.of.i32 = self->current_memory_offset, .kind = WASMTIME_I32},
+ };
+
+ wasmtime_error_t *error = wasmtime_func_call(context, &func, args, 1, NULL, 0, &trap);
+ assert(!error);
+ assert(!trap);
+}
+
+bool ts_wasm_store_start(TSWasmStore *self, TSLexer *lexer, const TSLanguage *language) {
+ uint32_t instance_index;
+ if (!ts_wasm_store_add_language(self, language, &instance_index)) return false;
+ self->current_lexer = lexer;
+ self->current_instance = &self->language_instances.contents[instance_index];
+ self->has_error = false;
+ ts_wasm_store_reset_heap(self);
+ return true;
+}
+
+void ts_wasm_store_reset(TSWasmStore *self) {
+ self->current_lexer = NULL;
+ self->current_instance = NULL;
+ self->has_error = false;
+ ts_wasm_store_reset_heap(self);
+}
+
+static void ts_wasm_store__call(
+ TSWasmStore *self,
+ int32_t function_index,
+ wasmtime_val_raw_t *args_and_results,
+ size_t args_and_results_len
+) {
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+ wasmtime_val_t value;
+ bool succeeded = wasmtime_table_get(context, &self->function_table, function_index, &value);
+ assert(succeeded);
+ assert(value.kind == WASMTIME_FUNCREF);
+ wasmtime_func_t func = value.of.funcref;
+
+ wasm_trap_t *trap = NULL;
+ wasmtime_error_t *error = wasmtime_func_call_unchecked(context, &func, args_and_results, args_and_results_len, &trap);
+ if (error) {
+ // wasm_message_t message;
+ // wasmtime_error_message(error, &message);
+ // fprintf(
+ // stderr,
+ // "error in wasm module: %.*s\n",
+ // (int)message.size, message.data
+ // );
+ wasmtime_error_delete(error);
+ self->has_error = true;
+ } else if (trap) {
+ // wasm_message_t message;
+ // wasm_trap_message(trap, &message);
+ // fprintf(
+ // stderr,
+ // "trap in wasm module: %.*s\n",
+ // (int)message.size, message.data
+ // );
+ wasm_trap_delete(trap);
+ self->has_error = true;
+ }
+}
+
+static bool ts_wasm_store__call_lex_function(TSWasmStore *self, unsigned function_index, TSStateId state) {
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+ uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
+ memcpy(
+ &memory_data[self->lexer_address],
+ &self->current_lexer->lookahead,
+ sizeof(self->current_lexer->lookahead)
+ );
+
+ wasmtime_val_raw_t args[2] = {
+ {.i32 = self->lexer_address},
+ {.i32 = state},
+ };
+ ts_wasm_store__call(self, function_index, args, 2);
+ if (self->has_error) return false;
+ bool result = args[0].i32;
+
+ memcpy(
+ &self->current_lexer->lookahead,
+ &memory_data[self->lexer_address],
+ sizeof(self->current_lexer->lookahead) + sizeof(self->current_lexer->result_symbol)
+ );
+ return result;
+}
+
+bool ts_wasm_store_call_lex_main(TSWasmStore *self, TSStateId state) {
+ return ts_wasm_store__call_lex_function(
+ self,
+ self->current_instance->lex_main_fn_index,
+ state
+ );
+}
+
+bool ts_wasm_store_call_lex_keyword(TSWasmStore *self, TSStateId state) {
+ return ts_wasm_store__call_lex_function(
+ self,
+ self->current_instance->lex_keyword_fn_index,
+ state
+ );
+}
+
+uint32_t ts_wasm_store_call_scanner_create(TSWasmStore *self) {
+ wasmtime_val_raw_t args[1] = {{.i32 = 0}};
+ ts_wasm_store__call(self, self->current_instance->scanner_create_fn_index, args, 1);
+ if (self->has_error) return 0;
+ return args[0].i32;
+}
+
+void ts_wasm_store_call_scanner_destroy(TSWasmStore *self, uint32_t scanner_address) {
+ if (self->current_instance) {
+ wasmtime_val_raw_t args[1] = {{.i32 = scanner_address}};
+ ts_wasm_store__call(self, self->current_instance->scanner_destroy_fn_index, args, 1);
+ }
+}
+
+bool ts_wasm_store_call_scanner_scan(
+ TSWasmStore *self,
+ uint32_t scanner_address,
+ uint32_t valid_tokens_ix
+) {
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+ uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
+
+ memcpy(
+ &memory_data[self->lexer_address],
+ &self->current_lexer->lookahead,
+ sizeof(self->current_lexer->lookahead)
+ );
+
+ uint32_t valid_tokens_address =
+ self->current_instance->external_states_address +
+ (valid_tokens_ix * sizeof(bool));
+ wasmtime_val_raw_t args[3] = {
+ {.i32 = scanner_address},
+ {.i32 = self->lexer_address},
+ {.i32 = valid_tokens_address}
+ };
+ ts_wasm_store__call(self, self->current_instance->scanner_scan_fn_index, args, 3);
+ if (self->has_error) return false;
+
+ memcpy(
+ &self->current_lexer->lookahead,
+ &memory_data[self->lexer_address],
+ sizeof(self->current_lexer->lookahead) + sizeof(self->current_lexer->result_symbol)
+ );
+ return args[0].i32;
+}
+
+uint32_t ts_wasm_store_call_scanner_serialize(
+ TSWasmStore *self,
+ uint32_t scanner_address,
+ char *buffer
+) {
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+ uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
+
+ wasmtime_val_raw_t args[2] = {
+ {.i32 = scanner_address},
+ {.i32 = self->serialization_buffer_address},
+ };
+ ts_wasm_store__call(self, self->current_instance->scanner_serialize_fn_index, args, 2);
+ if (self->has_error) return 0;
+
+ uint32_t length = args[0].i32;
+
+ if (length > 0) {
+ memcpy(
+ ((Lexer *)self->current_lexer)->debug_buffer,
+ &memory_data[self->serialization_buffer_address],
+ length
+ );
+ }
+ return length;
+}
+
+void ts_wasm_store_call_scanner_deserialize(
+ TSWasmStore *self,
+ uint32_t scanner_address,
+ const char *buffer,
+ unsigned length
+) {
+ wasmtime_context_t *context = wasmtime_store_context(self->store);
+ uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
+
+ if (length > 0) {
+ memcpy(
+ &memory_data[self->serialization_buffer_address],
+ buffer,
+ length
+ );
+ }
+
+ wasmtime_val_raw_t args[3] = {
+ {.i32 = scanner_address},
+ {.i32 = self->serialization_buffer_address},
+ {.i32 = length},
+ };
+ ts_wasm_store__call(self, self->current_instance->scanner_deserialize_fn_index, args, 3);
+}
+
+bool ts_wasm_store_has_error(const TSWasmStore *self) {
+ return self->has_error;
+}
+
+bool ts_language_is_wasm(const TSLanguage *self) {
+ return self->lex_fn == ts_wasm_store__sentinel_lex_fn;
+}
+
+static inline LanguageWasmModule *ts_language__wasm_module(const TSLanguage *self) {
+ return (LanguageWasmModule *)self->keyword_lex_fn;
+}
+
+void ts_wasm_language_retain(const TSLanguage *self) {
+ LanguageWasmModule *module = ts_language__wasm_module(self);
+ assert(module->ref_count > 0);
+ atomic_inc(&module->ref_count);
+}
+
+void ts_wasm_language_release(const TSLanguage *self) {
+ LanguageWasmModule *module = ts_language__wasm_module(self);
+ assert(module->ref_count > 0);
+ if (atomic_dec(&module->ref_count) == 0) {
+ // Update the language id to reflect that the language is deleted. This allows any wasm stores
+ // that hold wasm instances for this language to delete those instances.
+ atomic_inc(&module->language_id->is_language_deleted);
+ language_id_delete(module->language_id);
+
+ ts_free((void *)module->field_name_buffer);
+ ts_free((void *)module->symbol_name_buffer);
+ ts_free((void *)module->name);
+ wasmtime_module_delete(module->module);
+ ts_free(module);
+
+ ts_free((void *)self->alias_map);
+ ts_free((void *)self->alias_sequences);
+ ts_free((void *)self->external_scanner.symbol_map);
+ ts_free((void *)self->field_map_entries);
+ ts_free((void *)self->field_map_slices);
+ ts_free((void *)self->field_names);
+ ts_free((void *)self->lex_modes);
+ ts_free((void *)self->parse_actions);
+ ts_free((void *)self->parse_table);
+ ts_free((void *)self->primary_state_ids);
+ ts_free((void *)self->public_symbol_map);
+ ts_free((void *)self->small_parse_table);
+ ts_free((void *)self->small_parse_table_map);
+ ts_free((void *)self->symbol_metadata);
+ ts_free((void *)self->symbol_names);
+ ts_free((void *)self);
+ }
+}
+
+#else
+
+// If the WASM feature is not enabled, define dummy versions of all of the
+// wasm-related functions.
+
+void ts_wasm_store_delete(TSWasmStore *self) {
+ (void)self;
+}
+
+bool ts_wasm_store_start(
+ TSWasmStore *self,
+ TSLexer *lexer,
+ const TSLanguage *language
+) {
+ (void)self;
+ (void)lexer;
+ (void)language;
+ return false;
+}
+
+void ts_wasm_store_reset(TSWasmStore *self) {
+ (void)self;
+}
+
+bool ts_wasm_store_call_lex_main(TSWasmStore *self, TSStateId state) {
+ (void)self;
+ (void)state;
+ return false;
+}
+
+bool ts_wasm_store_call_lex_keyword(TSWasmStore *self, TSStateId state) {
+ (void)self;
+ (void)state;
+ return false;
+}
+
+uint32_t ts_wasm_store_call_scanner_create(TSWasmStore *self) {
+ (void)self;
+ return 0;
+}
+
+void ts_wasm_store_call_scanner_destroy(
+ TSWasmStore *self,
+ uint32_t scanner_address
+) {
+ (void)self;
+ (void)scanner_address;
+}
+
+bool ts_wasm_store_call_scanner_scan(
+ TSWasmStore *self,
+ uint32_t scanner_address,
+ uint32_t valid_tokens_ix
+) {
+ (void)self;
+ (void)scanner_address;
+ (void)valid_tokens_ix;
+ return false;
+}
+
+uint32_t ts_wasm_store_call_scanner_serialize(
+ TSWasmStore *self,
+ uint32_t scanner_address,
+ char *buffer
+) {
+ (void)self;
+ (void)scanner_address;
+ (void)buffer;
+ return 0;
+}
+
+void ts_wasm_store_call_scanner_deserialize(
+ TSWasmStore *self,
+ uint32_t scanner_address,
+ const char *buffer,
+ unsigned length
+) {
+ (void)self;
+ (void)scanner_address;
+ (void)buffer;
+ (void)length;
+}
+
+bool ts_wasm_store_has_error(const TSWasmStore *self) {
+ (void)self;
+ return false;
+}
+
+bool ts_language_is_wasm(const TSLanguage *self) {
+ (void)self;
+ return false;
+}
+
+void ts_wasm_language_retain(const TSLanguage *self) {
+ (void)self;
+}
+
+void ts_wasm_language_release(const TSLanguage *self) {
+ (void)self;
+}
+
+#endif