Skip to content

Commit

Permalink
feat: allow rules to be treesitter context aware
Browse files Browse the repository at this point in the history
When a rule is defined with the `:with_context()` method, has a
specified filetype, and is operating in a buffer with a treesitter
parser attached, the rule will only execute iff the treesitter language
at the cursor matches one of the filetypes specified in the initial rule
definition.

> If there are no specified filetypes, of there is no parser attached to
> the current buffer, the rule executes as normal
  • Loading branch information
kamalsacranie committed Jan 18, 2024
1 parent 9fd4118 commit 3534bfd
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
16 changes: 15 additions & 1 deletion lua/nvim-autopairs.lua
@@ -1,5 +1,6 @@
local log = require('nvim-autopairs._log')
local utils = require('nvim-autopairs.utils')
local ts_utils = require('nvim-autopairs.ts-utils')
local basic_rule = require('nvim-autopairs.rules.basic')
local api = vim.api
local highlighter = nil
Expand Down Expand Up @@ -398,11 +399,24 @@ M.autopairs_map = function(bufnr, char)
return char
end
local line = utils.text_get_current_line(bufnr)
local _, col = utils.get_cursor()
local row, col = utils.get_cursor()
local new_text = ''
local add_char = 1
local rules = M.get_buf_rules(bufnr)
for _, rule in pairs(rules) do
-- Rules executes as normal if no treesitter is attached to buffer or filetype not specified
if rule.filetypes and rule.is_context_aware and pcall(vim.treesitter.get_parser) then
local language_tree = ts_utils.get_language_tree_at_position({ row, col })
-- log.debug("cursor_position:" .. vim.inspect({ row, col }))
-- log.debug("language_tree:" .. vim.inspect(language_tree))
if language_tree then
local current_language_context = language_tree:lang()
-- log.debug("current_language_context:" .. current_language_context)
if not vim.tbl_contains(rule.filetypes, current_language_context) then
return char
end
end
end
if rule.start_pair then
if char:match('<.*>') then
new_text = line
Expand Down
7 changes: 7 additions & 0 deletions lua/nvim-autopairs/rule.lua
Expand Up @@ -13,6 +13,7 @@ local Cond = require('nvim-autopairs.conds')
--- @field is_multibyte boolean
--- @field is_endwise boolean only use on end_wise
--- @field is_undo boolean add break undo sequence
--- @field is_context_aware boolean only active in treesitter contexts specified in filetypes

local Rule = setmetatable({}, {
__call = function(self, ...)
Expand Down Expand Up @@ -50,6 +51,7 @@ function Rule.new(...)
is_regex = false,
is_multibyte = false,
end_pair_length = nil,
is_context_aware = false,
}, opt) or {}

---@param rule Rule
Expand Down Expand Up @@ -148,6 +150,11 @@ function Rule:set_end_pair_length(length)
return self
end

function Rule:with_context()
self.is_context_aware = true
return self
end

function Rule:with_move(cond)
if self.move_cond == nil then self.move_cond = {} end
table.insert(self.move_cond, cond)
Expand Down
12 changes: 12 additions & 0 deletions lua/nvim-autopairs/ts-utils.lua
@@ -1,6 +1,18 @@
local ts_get_node_text = vim.treesitter.get_node_text or vim.treesitter.query.get_node_text
local M = {}

--- Returns the language tree at the given position.
---@return LanguageTree
function M.get_language_tree_at_position(position)
local language_tree = vim.treesitter.get_parser()
language_tree:for_each_tree(function(_, tree)
if tree:contains(vim.tbl_flatten({ position, position })) then
language_tree = tree
end
end)
return language_tree
end

function M.get_tag_name(node)
local tag_name = nil
if node ~=nil then
Expand Down

0 comments on commit 3534bfd

Please sign in to comment.