Check for valid callback source address

This commit is contained in:
2026-03-14 16:16:10 +03:00
parent e4bfb49f21
commit b2566813ac
5 changed files with 43 additions and 1 deletions

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"regexp"
@@ -15,6 +16,7 @@ import (
"payouts/internal/service/database"
"payouts/internal/service/database/orm"
"payouts/internal/service/yookassa"
yookassaConf "payouts/internal/service/yookassa/config"
)
const (
@@ -30,6 +32,7 @@ type payoutHandler struct {
dbService database.Service
cacheService cache.Service
yooKassa yookassa.Service
yookassaConf yookassaConf.YooKassa
}
// Params represents the module input params
@@ -47,6 +50,7 @@ func NewPayoutHandler(p Params) (Handler, error) {
dbService: p.DbService,
cacheService: p.CacheService,
yooKassa: p.YooKassa,
yookassaConf: p.YooKassa.GetConfig(),
}, nil
}
@@ -70,6 +74,27 @@ func (p *payoutHandler) getSession(r *http.Request) (*orm.User, error) {
}
func (p *payoutHandler) checkAllowedIpCallback(ipStr string) bool {
ipWithoutPort, _, _ := net.SplitHostPort(ipStr)
ip := net.ParseIP(ipWithoutPort)
if ip == nil {
slog.Error(fmt.Sprintf("Invalid IP: %s", ipStr))
return false
}
for _, subnetStr := range p.yookassaConf.AllowedCallbackSubnets {
_, ipNet, err := net.ParseCIDR(subnetStr)
if err != nil {
slog.Error(fmt.Sprintf("Invalid subnet CIDR: %v", err))
continue
}
if ipNet.Contains(ip) {
return true
}
}
return false
}
// GetSbpBanks implements [Handler].
func (p *payoutHandler) GetSbpBanks(w http.ResponseWriter, r *http.Request) {
panic("unimplemented")
@@ -83,7 +108,7 @@ func (p *payoutHandler) PayoutCreate(w http.ResponseWriter, r *http.Request) {
_, err := p.getSession(r)
if err != nil {
errResponse("unautiorized", err, http.StatusUnauthorized)
errResponse("unauthorized", err, http.StatusUnauthorized)
}
panic("unimplemented")
@@ -95,5 +120,12 @@ func (p *payoutHandler) PayoutCallback(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
decoder.Decode(&inData)
// todo: check also the X-real-ip and/or X-Forwarded-For
if !p.checkAllowedIpCallback(r.RemoteAddr) {
slog.Error(fmt.Sprintf("Callback came from unallowed ip: %s", r.RemoteAddr))
http.Error(w, "unallowed", http.StatusForbidden)
return
}
slog.Info(fmt.Sprintf("Received callback from %s with object %v with headers %v", r.RemoteAddr, inData, r.Header))
}