Initial commit: docker-compose-updater
Build and Push / build (push) Failing after 13m20s

Go 项目,包含:
- 服务端 updater:两阶段协议,ECDSA 签名验证,AES-GCM 加密
- 发送端 dcu-send:Gitea Action CLI
- internal/auth:加解密/签名/会话管理
- internal/docker:Docker CLI 容器查找/拉取/重建
- action/:Gitea Action 定义
- deploy/Dockerfile:多阶段构建
- .gitea/workflows/build.yaml:CI/CD
This commit is contained in:
ilovintit
2026-06-08 15:16:46 +08:00
commit cea9b941cf
21 changed files with 1874 additions and 0 deletions
+75
View File
@@ -0,0 +1,75 @@
// Package auth 提供加解密、签名验签和会话密钥管理。
// 发送端 (dcu-send) 和接收端 (updater) 共用此包。
package auth
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// Encrypt 用 AES-256-GCM 加密明文,返回 nonce+ciphertext。
// key 必须是 32 字节。
func Encrypt(plaintext []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("aes new cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("aes gcm: %w", err)
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, fmt.Errorf("nonce: %w", err)
}
// GCM 附加认证数据 (AAD) 传空
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
return ciphertext, nil
}
// Decrypt 用 AES-256-GCM 解密,输入为 nonce+ciphertext。
func Decrypt(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("aes new cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("aes gcm: %w", err)
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("aes decrypt: %w", err)
}
return plaintext, nil
}
// DeriveKey 用 HKDF-SHA256 从 shared_secret 和 context 派生出 AES-256 密钥。
// salt: 每个请求的 nonce16 字节)
// info: 协议标识,如 "dcu-updater/v1"
func DeriveKey(sharedSecret []byte, salt []byte, info string) []byte {
hkdf := hkdf.New(sha256.New, sharedSecret, salt, []byte(info))
key := make([]byte, 32) // AES-256
if _, err := io.ReadFull(hkdf, key); err != nil {
// HKDF 使用 SHA-256,从不会返回错误
panic(fmt.Sprintf("hkdf read: %v", err))
}
return key
}
+39
View File
@@ -0,0 +1,39 @@
package auth
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/pem"
"fmt"
)
// ParseECDSAPrivateKey 解析 PEM 编码的 ECDSA 私钥。
func ParseECDSAPrivateKey(pemData []byte) (*ecdsa.PrivateKey, error) {
block, _ := pem.Decode(pemData)
if block == nil || block.Type != "EC PRIVATE KEY" {
return nil, fmt.Errorf("invalid EC private key PEM")
}
key, err := x509.ParseECPrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse EC private key: %w", err)
}
return key, nil
}
// ParseECDSAPublicKey 解析 PEM 编码的 ECDSA 公钥。
func ParseECDSAPublicKey(pemData []byte) (*ecdsa.PublicKey, error) {
block, _ := pem.Decode(pemData)
if block == nil || block.Type != "PUBLIC KEY" {
return nil, fmt.Errorf("invalid public key PEM")
}
key, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse public key: %w", err)
}
pubKey, ok := key.(*ecdsa.PublicKey)
if !ok {
return nil, fmt.Errorf("key is not ECDSA")
}
return pubKey, nil
}
+136
View File
@@ -0,0 +1,136 @@
package auth
import (
"crypto/rand"
"fmt"
"sync"
"time"
)
// sessionEntry 存储一次 Phase 1 协商的会话密钥。
type sessionEntry struct {
key []byte // AES-256 会话密钥
expiry time.Time
}
// SessionManager 管理内存中的会话密钥(Nonce → AES Key)。
// 每个 Phase 1 请求生成一个 sessionPhase 2 消费后删除。
type SessionManager struct {
mu sync.Mutex
store map[string]*sessionEntry
ttl time.Duration
}
// NewSessionManager 创建会话密钥管理器。
// ttl: 密钥过期时间(如 30 秒)
func NewSessionManager(ttl time.Duration) *SessionManager {
sm := &SessionManager{
store: make(map[string]*sessionEntry),
ttl: ttl,
}
// 后台协程定期清理过期密钥
go sm.cleanupLoop()
return sm
}
// GenerateKey 生成 AES-256 会话密钥,用指定的 keyID 存储(通常用 nonce)。
func (sm *SessionManager) GenerateKey(keyID string) ([]byte, error) {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return nil, fmt.Errorf("generate session key: %w", err)
}
sm.mu.Lock()
sm.store[keyID] = &sessionEntry{
key: key,
expiry: time.Now().Add(sm.ttl),
}
sm.mu.Unlock()
return key, nil
}
// GetKey 获取并删除会话密钥(一次性使用)。
// 返回 nil 表示 keyID 不存在或已过期。
func (sm *SessionManager) GetKey(keyID string) []byte {
sm.mu.Lock()
defer sm.mu.Unlock()
entry, ok := sm.store[keyID]
if !ok {
return nil
}
delete(sm.store, keyID)
if time.Now().After(entry.expiry) {
return nil
}
return entry.key
}
// cleanupLoop 定期清理过期密钥。
func (sm *SessionManager) cleanupLoop() {
ticker := time.NewTicker(sm.ttl)
defer ticker.Stop()
for range ticker.C {
sm.mu.Lock()
now := time.Now()
for id, entry := range sm.store {
if now.After(entry.expiry) {
delete(sm.store, id)
}
}
sm.mu.Unlock()
}
}
// NonceCache 用于防重放攻击的 Nonce 缓存。
type NonceCache struct {
mu sync.Mutex
store map[string]time.Time
ttl time.Duration
}
// NewNonceCache 创建 Nonce 缓存。
// ttl: Nonce 有效时间(如 60 秒)
func NewNonceCache(ttl time.Duration) *NonceCache {
nc := &NonceCache{
store: make(map[string]time.Time),
ttl: ttl,
}
go nc.cleanupLoop()
return nc
}
// Check 检查并记录 Nonce。返回 true 表示 Nonce 有效(未使用过)。
// 返回 false 表示 Nonce 已存在(重放攻击)。
func (nc *NonceCache) Check(nonce string) bool {
nc.mu.Lock()
defer nc.mu.Unlock()
if _, exists := nc.store[nonce]; exists {
return false
}
nc.store[nonce] = time.Now().Add(nc.ttl)
return true
}
func (nc *NonceCache) cleanupLoop() {
ticker := time.NewTicker(nc.ttl)
defer ticker.Stop()
for range ticker.C {
nc.mu.Lock()
now := time.Now()
for n, expiry := range nc.store {
if now.After(expiry) {
delete(nc.store, n)
}
}
nc.mu.Unlock()
}
}
+30
View File
@@ -0,0 +1,30 @@
package auth
import (
"crypto/ecdsa"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
)
// Sign ECDSA P-256 签名,返回 DER 编码的 Base64 签名。
func Sign(key *ecdsa.PrivateKey, data []byte) (string, error) {
hash := sha256.Sum256(data)
sig, err := ecdsa.SignASN1(rand.Reader, key, hash[:])
if err != nil {
return "", fmt.Errorf("ecdsa sign: %w", err)
}
return base64.StdEncoding.EncodeToString(sig), nil
}
// Verify 验证 ECDSA P-256 签名。
// sigBase64 是 DER 编码的 Base64 签名。
func Verify(key *ecdsa.PublicKey, data []byte, sigBase64 string) bool {
sig, err := base64.StdEncoding.DecodeString(sigBase64)
if err != nil {
return false
}
hash := sha256.Sum256(data)
return ecdsa.VerifyASN1(key, hash[:], sig)
}
+56
View File
@@ -0,0 +1,56 @@
// Package config 提供全局配置,全部通过环境变量设置。
package config
import (
"os"
"time"
)
// Config 应用配置。
type Config struct {
// 监听地址
Listen string
// Docker 操作超时
DockerPullTimeout time.Duration
DockerRestartTimeout time.Duration
// 会话密钥 TTL
SessionTTL time.Duration
// Nonce 缓存 TTL
NonceTTL time.Duration
// 时间戳容忍窗口
TimestampWindow time.Duration
// 日志级别
LogLevel string
}
// Load 从环境变量加载配置,未设置则用默认值。
func Load() *Config {
return &Config{
Listen: getEnv("LISTEN", ":8080"),
DockerPullTimeout: getDuration("DOCKER_PULL_TIMEOUT", 5*time.Minute),
DockerRestartTimeout: getDuration("DOCKER_RESTART_TIMEOUT", 30*time.Second),
SessionTTL: getDuration("SESSION_TTL", 30*time.Second),
NonceTTL: getDuration("NONCE_TTL", 60*time.Second),
TimestampWindow: getDuration("TIMESTAMP_WINDOW", 30*time.Second),
LogLevel: getEnv("LOG_LEVEL", "info"),
}
}
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
func getDuration(key string, fallback time.Duration) time.Duration {
if v := os.Getenv(key); v != "" {
d, err := time.ParseDuration(v)
if err == nil {
return d
}
}
return fallback
}
+283
View File
@@ -0,0 +1,283 @@
// Package docker 封装 Docker CLI 调用,负责查找、拉取和重建容器。
package docker
import (
"bytes"
"encoding/json"
"fmt"
"os/exec"
"strings"
"time"
)
const (
composeProjectLabel = "com.docker.compose.project"
composeServiceLabel = "com.docker.compose.service"
)
// Updater 封装 Docker 容器更新操作。
type Updater struct {
pullTimeout time.Duration
restartTimeout time.Duration
}
// NewUpdater 创建 Updater 实例。
func NewUpdater(pullTimeout, restartTimeout time.Duration) *Updater {
return &Updater{
pullTimeout: pullTimeout,
restartTimeout: restartTimeout,
}
}
// ContainerInfo 容器信息摘要。
type ContainerInfo struct {
ID string `json:"ID"`
Name string `json:"Name"`
Image string `json:"Image"`
Project string `json:"Project"`
Service string `json:"Service"`
}
// containerInspect 对应 docker inspect 的部分字段。
type containerInspect struct {
ID string `json:"Id"`
Name string `json:"Name"`
Config struct {
Image string `json:"Image"`
Cmd []string `json:"Cmd"`
Entrypoint []string `json:"Entrypoint"`
Env []string `json:"Env"`
ExposedPorts map[string]struct{} `json:"ExposedPorts"`
Labels map[string]string `json:"Labels"`
WorkingDir string `json:"WorkingDir"`
User string `json:"User"`
Hostname string `json:"Hostname"`
} `json:"Config"`
HostConfig struct {
NetworkMode string `json:"NetworkMode"`
Privileged bool `json:"Privileged"`
RestartPolicy struct {
Name string `json:"Name"`
} `json:"RestartPolicy"`
PortBindings map[string][]struct {
HostPort string `json:"HostPort"`
} `json:"PortBindings"`
Binds []string `json:"Binds"`
Links []string `json:"Links"`
ExtraHosts []string `json:"ExtraHosts"`
} `json:"HostConfig"`
Mounts []struct {
Type string `json:"Type"`
Source string `json:"Source"`
Target string `json:"Target"`
} `json:"Mounts"`
NetworkSettings struct {
Networks map[string]struct {
Aliases []string `json:"Aliases"`
} `json:"Networks"`
} `json:"NetworkSettings"`
}
// FindContainerByLabels 通过 compose project + service 标签查找容器。
func (u *Updater) FindContainerByLabels(project, service string) (*ContainerInfo, error) {
args := []string{
"ps",
"--filter", fmt.Sprintf("label=%s=%s", composeProjectLabel, project),
"--filter", fmt.Sprintf("label=%s=%s", composeServiceLabel, service),
"--format", "{{.ID}}\t{{.Image}}\t{{.Names}}",
"--latest",
}
out, err := u.docker(args...)
if err != nil {
return nil, fmt.Errorf("find container: %w", err)
}
lines := strings.Split(strings.TrimSpace(out), "\n")
if len(lines) == 0 || lines[0] == "" {
return nil, fmt.Errorf("container not found: project=%s service=%s", project, service)
}
parts := strings.SplitN(lines[0], "\t", 3)
if len(parts) < 3 {
return nil, fmt.Errorf("unexpected docker ps output: %s", out)
}
return &ContainerInfo{
ID: parts[0],
Image: parts[1],
Name: parts[2],
Project: project,
Service: service,
}, nil
}
// PullImage 拉取指定镜像。
func (u *Updater) PullImage(imageName string) error {
_, err := u.docker("pull", imageName)
return err
}
// RecreateContainer 重建容器:拉取 → 重命名 → 创建 → 启动 → 清理。
// 返回新容器 ID。镜像无变化时跳过重建。
func (u *Updater) RecreateContainer(project, service string) (string, error) {
info, err := u.FindContainerByLabels(project, service)
if err != nil {
return "", err
}
// 拉取新镜像
if err := u.PullImage(info.Image); err != nil {
return "", fmt.Errorf("pull: %w", err)
}
// 检查镜像是否有变化
oldImageID, err := u.getContainerImageID(info.ID)
if err != nil {
return "", fmt.Errorf("old image ID: %w", err)
}
newImageID, err := u.getImageID(info.Image)
if err != nil {
return "", fmt.Errorf("new image ID: %w", err)
}
if oldImageID == newImageID {
return info.ID, nil
}
// 检查旧容器配置
inspectJSON, err := u.inspectContainer(info.ID)
if err != nil {
return "", fmt.Errorf("inspect: %w", err)
}
oldName := strings.TrimPrefix(inspectJSON.Name, "/")
oldRenamed := oldName + "-old-" + time.Now().Format("150405")
// 重命名旧容器
if _, err := u.docker("rename", info.ID, oldRenamed); err != nil {
return "", fmt.Errorf("rename old: %w", err)
}
// 创建新容器
runArgs := u.buildRunArgs(inspectJSON, info.Image, oldName)
createOut, err := u.docker(runArgs...)
if err != nil {
_, _ = u.docker("rename", info.ID, oldName) // 恢复
return "", fmt.Errorf("create: %w", err)
}
newID := strings.TrimSpace(createOut)
// 删除旧容器
if _, err := u.docker("rm", "-f", info.ID); err != nil {
fmt.Printf("warning: remove old %s: %v\n", info.ID[:12], err)
}
return newID, nil
}
// RestartContainer 仅重启(会拉取最新镜像后再重建)。
func (u *Updater) RestartContainer(project, service string) (string, error) {
_, err := u.FindContainerByLabels(project, service)
if err != nil {
return "", err
}
return u.RecreateContainer(project, service)
}
func (u *Updater) getContainerImageID(containerID string) (string, error) {
return u.inspectField(containerID, "{{.Image}}")
}
func (u *Updater) getImageID(imageName string) (string, error) {
out, err := u.docker("image", "inspect", "--format", "{{.Id}}", imageName)
if err != nil {
return "", err
}
return strings.TrimSpace(out), nil
}
func (u *Updater) inspectField(containerID, format string) (string, error) {
out, err := u.docker("inspect", "--format", format, containerID)
if err != nil {
return "", err
}
return strings.TrimSpace(out), nil
}
func (u *Updater) inspectContainer(containerID string) (*containerInspect, error) {
out, err := u.docker("inspect", containerID)
if err != nil {
return nil, err
}
var containers []containerInspect
if err := json.Unmarshal([]byte(out), &containers); err != nil {
return nil, fmt.Errorf("parse inspect: %w", err)
}
if len(containers) == 0 {
return nil, fmt.Errorf("container not found: %s", containerID)
}
return &containers[0], nil
}
// buildRunArgs 从旧容器配置构建 docker run 参数。
func (u *Updater) buildRunArgs(inspect *containerInspect, imageName, containerName string) []string {
args := []string{"run", "-d"}
args = append(args, "--name", containerName)
rp := inspect.HostConfig.RestartPolicy.Name
if rp == "" {
rp = "unless-stopped"
}
args = append(args, "--restart", rp)
for _, env := range inspect.Config.Env {
args = append(args, "-e", env)
}
for _, m := range inspect.Mounts {
if m.Source == "" {
continue
}
args = append(args, "-v", fmt.Sprintf("%s:%s", m.Source, m.Target))
}
for port, bindings := range inspect.HostConfig.PortBindings {
for _, b := range bindings {
if b.HostPort == "" {
args = append(args, "-p", port)
} else {
args = append(args, "-p", fmt.Sprintf("%s:%s", b.HostPort, port))
}
}
}
if nm := inspect.HostConfig.NetworkMode; nm != "" && nm != "default" {
args = append(args, "--network", string(nm))
}
for _, h := range inspect.HostConfig.ExtraHosts {
args = append(args, "--add-host", h)
}
for k, v := range inspect.Config.Labels {
args = append(args, "-l", fmt.Sprintf("%s=%s", k, v))
}
if wd := inspect.Config.WorkingDir; wd != "" {
args = append(args, "-w", wd)
}
if u := inspect.Config.User; u != "" {
args = append(args, "-u", u)
}
if hn := inspect.Config.Hostname; hn != "" {
args = append(args, "--hostname", hn)
}
if inspect.HostConfig.Privileged {
args = append(args, "--privileged")
}
for _, link := range inspect.HostConfig.Links {
args = append(args, "--link", link)
}
args = append(args, imageName)
return args
}
// docker 执行 docker CLI 命令。
func (u *Updater) docker(args ...string) (string, error) {
cmd := exec.Command("docker", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("docker %s: %v\nstderr: %s",
strings.Join(args, " "), err, strings.TrimSpace(stderr.String()))
}
return stdout.String(), nil
}
+255
View File
@@ -0,0 +1,255 @@
package server
import (
"crypto/ecdsa"
"encoding/base64"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"time"
"gitea.songhuwan.com/actions/docker-compose-updater/internal/auth"
"gitea.songhuwan.com/actions/docker-compose-updater/internal/docker"
)
// Handler 持有 HTTP handler 的依赖。
type Handler struct {
verifyKey *ecdsa.PublicKey
sessionMgr *auth.SessionManager
nonceCache *auth.NonceCache
docker *docker.Updater
timeWindow time.Duration
}
// NewHandler 创建 Handler。
func NewHandler(
verifyKey *ecdsa.PublicKey,
sessionMgr *auth.SessionManager,
nonceCache *auth.NonceCache,
docker *docker.Updater,
timeWindow time.Duration,
) *Handler {
return &Handler{
verifyKey: verifyKey,
sessionMgr: sessionMgr,
nonceCache: nonceCache,
docker: docker,
timeWindow: timeWindow,
}
}
// --- 请求/响应结构 ---
// phase1Req 第一阶段请求。
type phase1Req struct {
V int `json:"v"`
TS int64 `json:"ts"`
Nonce string `json:"nonce"`
Sig string `json:"sig"`
}
// phase2Req 第二阶段请求。
type phase2Req struct {
V int `json:"v"`
TS int64 `json:"ts"`
Nonce string `json:"nonce"`
KeyID string `json:"key_id"`
Data string `json:"data"` // AES-GCM 密文 base64
}
// plainPayload 解密后的明文请求。
type plainPayload struct {
Project string `json:"project"`
Service string `json:"service"`
Action string `json:"action"` // update / pull / restart
}
// --- 处理函数 ---
// HandleSession Phase 1: 验证签名 → 生成会话密钥 → 返回。
func (h *Handler) HandleSession(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "only POST allowed")
return
}
var req phase1Req
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if req.V != 1 {
writeError(w, http.StatusBadRequest, "unsupported protocol version")
return
}
if req.Sig == "" {
writeError(w, http.StatusBadRequest, "missing sig")
return
}
if err := h.checkTimestamp(req.TS); err != nil {
writeError(w, http.StatusUnauthorized, err.Error())
return
}
if !h.nonceCache.Check(req.Nonce) {
writeError(w, http.StatusUnauthorized, "nonce reused")
return
}
// 验证 ECDSA 签名:sign("v.ts.nonce")
signData := fmt.Sprintf("%d.%d.%s", req.V, req.TS, req.Nonce)
if !auth.Verify(h.verifyKey, []byte(signData), req.Sig) {
writeError(w, http.StatusUnauthorized, "signature verification failed")
return
}
// 生成会话密钥,用 nonce 作为 key_id
key := make([]byte, 32)
// 使用随机数作为会话密钥
keyBytes, err := h.sessionMgr.GenerateKey(req.Nonce)
if err != nil {
slog.Error("generate session key", "error", err)
writeError(w, http.StatusInternalServerError, "internal error")
return
}
_ = key
writeJSON(w, http.StatusOK, map[string]any{
"key": base64.StdEncoding.EncodeToString(keyBytes),
"key_id": req.Nonce,
"expires_in": 30,
})
}
// HandleHook Phase 2: 解密 payload → 执行 Docker 操作。
func (h *Handler) HandleHook(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "only POST allowed")
return
}
var req phase2Req
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if req.V != 1 {
writeError(w, http.StatusBadRequest, "unsupported protocol version")
return
}
if err := h.checkTimestamp(req.TS); err != nil {
writeError(w, http.StatusUnauthorized, err.Error())
return
}
if !h.nonceCache.Check(req.Nonce) {
writeError(w, http.StatusUnauthorized, "nonce reused")
return
}
// 获取会话密钥
sessionKey := h.sessionMgr.GetKey(req.KeyID)
if sessionKey == nil {
writeError(w, http.StatusUnauthorized, "invalid or expired session key")
return
}
// 解密 payload
ciphertext, err := base64.StdEncoding.DecodeString(req.Data)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid data encoding")
return
}
plaintext, err := auth.Decrypt(ciphertext, sessionKey)
if err != nil {
writeError(w, http.StatusUnauthorized, "decrypt failed")
return
}
var payload plainPayload
if err := json.Unmarshal(plaintext, &payload); err != nil {
writeError(w, http.StatusBadRequest, "invalid payload json")
return
}
if payload.Project == "" || payload.Service == "" || payload.Action == "" {
writeError(w, http.StatusBadRequest, "project, service, action required")
return
}
slog.Info("hook",
"project", payload.Project,
"service", payload.Service,
"action", payload.Action,
)
var result string
switch payload.Action {
case "pull":
info, err := h.docker.FindContainerByLabels(payload.Project, payload.Service)
if err != nil {
writeError(w, http.StatusNotFound, err.Error())
return
}
if err := h.docker.PullImage(info.Image); err != nil {
writeError(w, http.StatusInternalServerError, "pull: "+err.Error())
return
}
result = "image pulled"
case "update":
id, err := h.docker.RecreateContainer(payload.Project, payload.Service)
if err != nil {
writeError(w, http.StatusInternalServerError, "update: "+err.Error())
return
}
result = fmt.Sprintf("container recreated: %s", id[:12])
case "restart":
id, err := h.docker.RestartContainer(payload.Project, payload.Service)
if err != nil {
writeError(w, http.StatusInternalServerError, "restart: "+err.Error())
return
}
result = fmt.Sprintf("container restarted: %s", id[:12])
default:
writeError(w, http.StatusBadRequest, "unknown action: "+payload.Action)
return
}
writeJSON(w, http.StatusOK, map[string]string{
"status": "ok",
"project": payload.Project,
"service": payload.Service,
"action": payload.Action,
"message": result,
})
}
// Health 健康检查。
func (h *Handler) Health(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
// --- 工具函数 ---
func (h *Handler) checkTimestamp(ts int64) error {
now := time.Now().Unix()
diff := now - ts
if diff < 0 {
diff = -diff
}
if diff > int64(h.timeWindow.Seconds()) {
return fmt.Errorf("timestamp outside window")
}
return nil
}
func writeJSON(w http.ResponseWriter, status int, v any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(v)
}
func writeError(w http.ResponseWriter, status int, msg string) {
writeJSON(w, status, map[string]string{"error": msg})
}
+59
View File
@@ -0,0 +1,59 @@
package server
import (
"log/slog"
"net/http"
"time"
"crypto/ecdsa"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"gitea.songhuwan.com/actions/docker-compose-updater/internal/auth"
"gitea.songhuwan.com/actions/docker-compose-updater/internal/docker"
)
// Server 封装 HTTP 服务器。
type Server struct {
addr string
router *chi.Mux
srv *http.Server
}
// New 创建 Server。
func New(
addr string,
verifyKey *ecdsa.PublicKey,
updater *docker.Updater,
sessionTTL time.Duration,
nonceTTL time.Duration,
timeWindow time.Duration,
) *Server {
sessionMgr := auth.NewSessionManager(sessionTTL)
nonceCache := auth.NewNonceCache(nonceTTL)
h := NewHandler(verifyKey, sessionMgr, nonceCache, updater, timeWindow)
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
r.Use(middleware.RequestSize(1024 * 1024)) // 1MB
r.Get("/health", h.Health)
r.Post("/session", h.HandleSession)
r.Post("/hook", h.HandleHook)
return &Server{
addr: addr,
router: r,
}
}
// Start 启动 HTTP 服务器。
func (s *Server) Start() error {
s.srv = &http.Server{
Addr: s.addr,
Handler: s.router,
}
slog.Info("server starting", "addr", s.addr)
return s.srv.ListenAndServe()
}