Merge pull request #533 from jakubbortlik/fix/check-local-in-sync-before-commenting

feat: check local branch state when creating comments
This commit is contained in:
Harrison (Harry) Cramer
2026-03-17 20:37:34 -04:00
committed by GitHub
6 changed files with 70 additions and 46 deletions

View File

@@ -349,6 +349,10 @@ M.can_create_comment = function(must_be_visual)
return false return false
end end
if not git.check_current_branch_up_to_date_on_remote(vim.log.levels.ERROR) then
return false
end
-- Check we're in visual mode for code suggestions and multiline comments -- Check we're in visual mode for code suggestions and multiline comments
if must_be_visual and not u.check_visual_mode() then if must_be_visual and not u.check_visual_mode() then
return false return false

View File

@@ -55,6 +55,9 @@ end
---Makes API call to get the discussion data, stores it in the state, and calls the callback ---Makes API call to get the discussion data, stores it in the state, and calls the callback
---@param callback function|nil ---@param callback function|nil
M.load_discussions = function(callback) M.load_discussions = function(callback)
local git = require("gitlab.git")
local ahead, behind = git.get_ahead_behind(git.get_current_branch(), git.get_remote_branch())
state.ahead_behind = { ahead, behind }
state.discussion_tree.last_updated = nil state.discussion_tree.last_updated = nil
state.load_new_state("discussion_data", function(data) state.load_new_state("discussion_data", function(data)
if not state.DISCUSSION_DATA then if not state.DISCUSSION_DATA then

View File

@@ -94,6 +94,8 @@ local function content()
resolved_notes = resolved_notes, resolved_notes = resolved_notes,
non_resolvable_notes = non_resolvable_notes, non_resolvable_notes = non_resolvable_notes,
help_keymap = state.settings.keymaps.help, help_keymap = state.settings.keymaps.help,
ahead = state.ahead_behind[1],
behind = state.ahead_behind[2],
updated = updated, updated = updated,
} }
@@ -158,7 +160,7 @@ end
M.make_winbar = function(t) M.make_winbar = function(t)
local discussions_focused = M.current_view_type == "discussions" local discussions_focused = M.current_view_type == "discussions"
local discussion_text = add_drafts_and_resolvable( local discussion_text = add_drafts_and_resolvable(
"Inline Comments:", "Comments:",
t.resolvable_discussions, t.resolvable_discussions,
t.resolved_discussions, t.resolved_discussions,
t.inline_draft_notes, t.inline_draft_notes,
@@ -190,15 +192,18 @@ M.make_winbar = function(t)
local separator = "%#Comment#|" local separator = "%#Comment#|"
local end_section = "%=" local end_section = "%="
local updated = "%#Text#" .. t.updated local updated = "%#Text#" .. t.updated
local ahead_behind = M.get_ahead_behind(t.ahead, t.behind)
local help = "%#Comment#Help: " .. (t.help_keymap and t.help_keymap:gsub(" ", "<space>") .. " " or "unmapped") local help = "%#Comment#Help: " .. (t.help_keymap and t.help_keymap:gsub(" ", "<space>") .. " " or "unmapped")
return string.format( return string.format(
" %s %s %s %s %s %s %s %s %s %s %s", " %s %s %s %s %s %s %s %s %s %s %s %s %s",
discussion_text, discussion_text,
separator, separator,
notes_text, notes_text,
end_section, end_section,
updated, updated,
separator, separator,
ahead_behind,
separator,
sort_method, sort_method,
separator, separator,
mode, mode,
@@ -254,6 +259,16 @@ M.get_mode = function()
end end
end end
---@param ahead number|nil
---@param behind number|nil
M.get_ahead_behind = function(ahead, behind)
local a = ahead == nil and "?" or tostring(ahead)
local b = behind == nil and "?" or tostring(behind)
a = ((a == "?" or a == "0") and "%#Comment#" or "%#WarningMsg#") .. a
b = ((b == "?" or b == "0") and "%#Comment#" or "%#WarningMsg#") .. b
return a .. "" .. b .. ""
end
---Toggles the current view type (or sets it to `override`) and then updates the view. ---Toggles the current view type (or sets it to `override`) and then updates the view.
---@param override "discussions"|"notes" Defines the view type to select. ---@param override "discussions"|"notes" Defines the view type to select.
M.switch_view_type = function(override) M.switch_view_type = function(override)

View File

@@ -92,6 +92,8 @@
---@field resolved_notes number ---@field resolved_notes number
---@field non_resolvable_notes number ---@field non_resolvable_notes number
---@field help_keymap string ---@field help_keymap string
---@field ahead number|nil -- Number of commits local is ahead of remote
---@field behind number|nil -- Number of commits local is behind remote
---@field updated string ---@field updated string
--- ---
---@class SignTable ---@class SignTable

View File

@@ -74,14 +74,17 @@ M.fetch_remote_branch = function(remote_branch)
return true return true
end end
---Determines whether the tracking branch is ahead of or behind the current branch, and warns the user if so ---Determines whether the tracking branch is ahead of or behind the current branch and returns the
---@param current_branch string ---number of ahead and behind commits or nil values in case of errors.
---@param remote_branch string ---@param current_branch string|nil
---@param log_level number ---@param remote_branch string|nil
---@return boolean ---@return integer|nil ahead, integer|nil behind
M.get_ahead_behind = function(current_branch, remote_branch, log_level) M.get_ahead_behind = function(current_branch, remote_branch)
if current_branch == nil or remote_branch == nil then
return nil, nil
end
if not M.fetch_remote_branch(remote_branch) then if not M.fetch_remote_branch(remote_branch) then
return false return nil, nil
end end
local u = require("gitlab.utils") local u = require("gitlab.utils")
@@ -89,39 +92,16 @@ M.get_ahead_behind = function(current_branch, remote_branch, log_level)
run_system({ "git", "rev-list", "--left-right", "--count", current_branch .. "..." .. remote_branch }) run_system({ "git", "rev-list", "--left-right", "--count", current_branch .. "..." .. remote_branch })
if err ~= nil or result == nil then if err ~= nil or result == nil then
u.notify("Could not determine if branch is up-to-date: " .. err, vim.log.levels.ERROR) u.notify("Could not determine if branch is up-to-date: " .. err, vim.log.levels.ERROR)
return false return nil, nil
end end
local ahead, behind = result:match("(%d+)%s+(%d+)") local ahead, behind = result:match("(%d+)%s+(%d+)")
if ahead == nil or behind == nil then if ahead == nil or behind == nil then
u.notify("Error parsing ahead/behind information.", vim.log.levels.ERROR) u.notify("Error parsing ahead/behind information", vim.log.levels.ERROR)
return false return nil, nil
end end
ahead = tonumber(ahead) return tonumber(ahead), tonumber(behind)
behind = tonumber(behind)
if ahead > 0 and behind == 0 then
u.notify(string.format("There are local changes that haven't been pushed to %s yet", remote_branch), log_level)
return false
end
if behind > 0 and ahead == 0 then
u.notify(string.format("There are remote changes on %s that haven't been pulled yet", remote_branch), log_level)
return false
end
if ahead > 0 and behind > 0 then
u.notify(
string.format(
"Your branch and the remote %s have diverged. You need to pull, possibly rebase, and then push.",
remote_branch
),
log_level
)
return false
end
return true -- Checks passed, branch is up-to-date
end end
---Return the name of the current branch or nil if it can't be retrieved ---Return the name of the current branch or nil if it can't be retrieved
@@ -184,16 +164,35 @@ end
---@return boolean ---@return boolean
M.check_current_branch_up_to_date_on_remote = function(log_level) M.check_current_branch_up_to_date_on_remote = function(log_level)
local current_branch = M.get_current_branch() local current_branch = M.get_current_branch()
if current_branch == nil then
return false
end
local remote_branch = M.get_remote_branch() local remote_branch = M.get_remote_branch()
if remote_branch == nil then local ahead, behind = M.get_ahead_behind(current_branch, remote_branch)
if ahead == nil or behind == nil then
return false return false
end end
return M.get_ahead_behind(current_branch, remote_branch, log_level) local u = require("gitlab.utils")
if ahead > 0 and behind == 0 then
u.notify(string.format("There are local changes that haven't been pushed to %s yet", remote_branch), log_level)
return false
end
if behind > 0 and ahead == 0 then
u.notify(string.format("There are remote changes on %s that haven't been pulled yet", remote_branch), log_level)
return false
end
if ahead > 0 and behind > 0 then
u.notify(
string.format(
"Your branch and the remote %s have diverged. You need to pull, possibly rebase, and then push.",
remote_branch
),
log_level
)
return false
end
return true -- Checks passed, branch is up-to-date
end end
---Warns user if the current MR is in a bad state (closed, has conflicts, merged) ---Warns user if the current MR is in a bad state (closed, has conflicts, merged)

View File

@@ -3,16 +3,17 @@
-- This module is also responsible for ensuring that the state of the plugin -- This module is also responsible for ensuring that the state of the plugin
-- is valid via dependencies -- is valid via dependencies
local git = require("gitlab.git")
local u = require("gitlab.utils") local u = require("gitlab.utils")
local List = require("gitlab.utils.list") local List = require("gitlab.utils.list")
local M = {} local M = {
emoji_map = nil,
M.emoji_map = nil ahead_behind = { nil, nil },
}
---Returns a gitlab token, and a gitlab URL. Used to connect to gitlab. ---Returns a gitlab token, and a gitlab URL. Used to connect to gitlab.
---@return string|nil, string|nil, string|nil ---@return string|nil, string|nil, string|nil
M.default_auth_provider = function() M.default_auth_provider = function()
local git = require("gitlab.git")
local base_path, err = M.settings.config_path, nil local base_path, err = M.settings.config_path, nil
if base_path == nil then if base_path == nil then
base_path, err = git.base_dir() base_path, err = git.base_dir()