diff --git a/internal/watch/record.go b/internal/watch/record.go index 245dfdb..f4bb8a5 100644 --- a/internal/watch/record.go +++ b/internal/watch/record.go @@ -8,111 +8,145 @@ import ( "github.com/codfrm/cago/pkg/logger" "github.com/codfrm/dnspod-watch/pkg/pushcat" - "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" dnspod "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod/v20210323" "go.uber.org/zap" ) type record struct { - w *watch - record *dnspod.RecordListItem - isDisable bool - domain, value string - logger *logger.CtxLogger + w *watch + record *dnspod.RecordListItem + loadBalance *dnspod.RecordListItem + isDisable bool + domain *CheckDomain + logger *logger.CtxLogger } -func newRecord(w *watch, r *dnspod.RecordListItem, domain, value string) *record { - return &record{ +func newRecord(w *watch, r *dnspod.RecordListItem, domain *CheckDomain) (*record, error) { + ret := &record{ w: w, record: r, isDisable: false, - domain: domain, value: value, + domain: domain, logger: logger.NewCtxLogger(logger.Default()).With( - zap.String("domain", domain), zap.String("value", value), + zap.String("domain", domain.Domain), zap.String("value", domain.Value), ), } + if domain.LoadBalance != nil { + var err error + ret.loadBalance, err = w.queryRecord(domain.Domain, domain.Name, + domain.LoadBalance.Value, domain.LoadBalance.Line) + if err != nil { + return nil, err + } + } + return ret, nil } // watch 每分钟检查ip是否可以访问, 无法访问自动暂停记录 func (r *record) watch(ctx context.Context) { - t := time.NewTicker(time.Minute) - lastSwitch := time.Now() - count := 0 + t := time.NewTicker(time.Second) + s := newRetry() + loadBalance := newRetry() for { select { case <-t.C: - // 检查ip是否可以访问 - count++ - if err := r.checkIP(ctx, r.value); err != nil { + duration, err := r.checkIP(ctx, r.domain.Value) + if err != nil { r.logger.Ctx(ctx).Error("check ip err", zap.Error(err)) - // 连续3次无法访问,暂停记录 - if !r.isDisable && count > 3 { - count = 0 - // 暂停记录 - request := dnspod.NewModifyRecordStatusRequest() - request.SetContext(ctx) - request.Domain = common.StringPtr(r.domain) - request.RecordId = common.Uint64Ptr(*r.record.RecordId) - request.Status = common.StringPtr("DISABLE") - _, err := r.w.client.ModifyRecordStatus(request) - msg := fmt.Sprintf("域名: %s, 记录: %s, ip无法访问,暂停记录", r.domain, r.value) - if err != nil { - r.logger.Ctx(ctx).Error("modify record status err", zap.Error(err)) - msg += "\n记录修改失败: " + err.Error() - } else { - r.logger.Ctx(ctx).Info("modify record status success", - zap.String("status", "DISABLE")) - r.isDisable = true - } - if err := pushcat.Send(ctx, "ip无法访问,暂停记录", msg); err != nil { - r.logger.Ctx(ctx).Error("发送通知错误", - zap.Error(err), - zap.String("msg", msg)) - } + } else { + r.logger.Ctx(ctx).Info("check ip ok", zap.Duration("duration", duration)) + } + _ = s.check(err == nil, func() error { + // 增加负载均衡 + if r.domain.LoadBalance != nil { + // 判断延迟是否超过200毫秒 + loadBalance.check(duration > time.Millisecond*100, func() error { + // 开启记录 + msg := fmt.Sprintf("开启负载均衡 域名: %s, 记录: %s", + r.domain.Domain, r.domain.LoadBalance.Value) + enableErr := r.w.enable(ctx, r.domain.Domain, *r.loadBalance.RecordId) + if enableErr != nil { + r.logger.Ctx(ctx).Error("modify record status err", zap.Error(enableErr)) + msg += "\n记录修改失败: " + enableErr.Error() + } else { + r.logger.Ctx(ctx).Info("modify record status success", + zap.String("status", "ENABLE")) + r.isDisable = false + } + if pushErr := pushcat.Send(ctx, "开启负载均衡", msg); pushErr != nil { + r.logger.Ctx(ctx).Error("发送通知错误", + zap.Error(pushErr), + zap.String("msg", msg)) + } + return enableErr + }, func() error { + // 开启记录 + msg := fmt.Sprintf("关闭负载均衡 域名: %s, 记录: %s", + r.domain.Domain, r.domain.LoadBalance.Value) + disableErr := r.w.disable(ctx, r.domain.Domain, *r.loadBalance.RecordId) + if disableErr != nil { + r.logger.Ctx(ctx).Error("modify record status err", zap.Error(disableErr)) + msg += "\n记录修改失败: " + disableErr.Error() + } else { + r.logger.Ctx(ctx).Info("modify record status success", + zap.String("status", "DISABLE")) + r.isDisable = false + } + if pushErr := pushcat.Send(ctx, "关闭负载均衡", msg); pushErr != nil { + r.logger.Ctx(ctx).Error("发送通知错误", + zap.Error(pushErr), + zap.String("msg", msg)) + } + return disableErr + }) } - } else if r.isDisable && count > 3 { - // 上次切换时间超过30分钟才能再次切换 - if time.Since(lastSwitch) < time.Minute*30 { - r.logger.Ctx(ctx).Info("ip可以访问,但是上次切换时间不足30分钟") - continue - } - lastSwitch = time.Now() - - // 检查连续成功3次,开启记录 - count = 0 - request := dnspod.NewModifyRecordStatusRequest() - request.SetContext(ctx) - request.Domain = common.StringPtr(r.domain) - request.RecordId = common.Uint64Ptr(*r.record.RecordId) - request.Status = common.StringPtr("ENABLE") - _, err := r.w.client.ModifyRecordStatus(request) - msg := fmt.Sprintf("域名: %s, 记录: %s, ip可以访问,开启记录", r.domain, r.value) - if err != nil { - r.logger.Ctx(ctx).Error("modify record status err", zap.Error(err)) - msg += "\n记录修改失败: " + err.Error() + // 开启记录 + msg := fmt.Sprintf("域名: %s, 记录: %s, ip可以访问,开启记录", r.domain.Domain, r.domain.Value) + enableErr := r.w.enable(ctx, r.domain.Domain, *r.record.RecordId) + if enableErr != nil { + r.logger.Ctx(ctx).Error("modify record status err", zap.Error(enableErr)) + msg += "\n记录修改失败: " + enableErr.Error() } else { r.logger.Ctx(ctx).Info("modify record status success", zap.String("status", "ENABLE")) r.isDisable = false } - if err := pushcat.Send(ctx, "ip可以访问,开启记录", msg); err != nil { + if pushErr := pushcat.Send(ctx, "ip可以访问,开启记录", msg); pushErr != nil { r.logger.Ctx(ctx).Error("发送通知错误", - zap.Error(err), + zap.Error(pushErr), zap.String("msg", msg)) } - } else { - r.logger.Ctx(ctx).Info("ip is ok") - } + return enableErr + }, func() error { + // 暂停记录 + msg := fmt.Sprintf("域名: %s, 记录: %s, ip无法访问,暂停记录", r.domain.Domain, r.domain.Value) + disableErr := r.w.disable(ctx, r.domain.Domain, *r.record.RecordId) + if disableErr != nil { + r.logger.Ctx(ctx).Error("modify record status err", zap.Error(disableErr)) + msg += "\n记录修改失败: " + disableErr.Error() + } else { + r.logger.Ctx(ctx).Info("modify record status success", + zap.String("status", "DISABLE")) + r.isDisable = true + } + if pushErr := pushcat.Send(ctx, "ip无法访问,暂停记录", msg); pushErr != nil { + r.logger.Ctx(ctx).Error("发送通知错误", + zap.Error(pushErr), + zap.String("msg", msg)) + } + return disableErr + }) case <-ctx.Done(): t.Stop() } } } -func (r *record) checkIP(ctx context.Context, ip string) error { +func (r *record) checkIP(ctx context.Context, ip string) (time.Duration, error) { + ts := time.Now() con, err := net.DialTimeout("tcp", ip+":80", time.Second*10) if err != nil { - return err + return 0, err } - return con.Close() + return time.Since(ts), con.Close() } diff --git a/internal/watch/switch.go b/internal/watch/switch.go new file mode 100644 index 0000000..cea7dfe --- /dev/null +++ b/internal/watch/switch.go @@ -0,0 +1,51 @@ +package watch + +import "time" + +type retry struct { + lastStatus bool + currentStatus bool + count int + lastTime time.Time +} + +func newRetry() *retry { + return &retry{ + count: 0, + //lastTime: time.Now(), + } +} + +func (r *retry) check(check bool, ok func() error, bad func() error) error { + if check { + if r.lastStatus { + r.count += 1 + } else { + r.count = 0 + } + r.lastStatus = true + if !r.currentStatus && r.count > 3 { + if time.Since(r.lastTime) < time.Minute*60 { + return nil + } + if err := ok(); err == nil { + r.currentStatus = true + r.lastTime = time.Now() + } + } + } else { + if !r.lastStatus { + r.count += 1 + } else { + r.count = 0 + } + r.lastStatus = false + if r.currentStatus && r.count > 3 { + if err := bad(); err == nil { + r.currentStatus = false + r.lastTime = time.Now() + } + } + } + return nil +} diff --git a/internal/watch/utils.go b/internal/watch/utils.go index 847473c..f73648e 100644 --- a/internal/watch/utils.go +++ b/internal/watch/utils.go @@ -1,15 +1,15 @@ package watch import ( + "context" "errors" - "github.com/codfrm/cago/pkg/logger" "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" dnspod "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod/v20210323" "go.uber.org/zap" ) -func (w *watch) queryRecord(domain, name, value string) (*dnspod.RecordListItem, error) { +func (w *watch) queryRecord(domain, name, value, line string) (*dnspod.RecordListItem, error) { // 实例化一个请求对象,每个接口都会对应一个request对象 request := dnspod.NewDescribeRecordListRequest() @@ -22,10 +22,40 @@ func (w *watch) queryRecord(domain, name, value string) (*dnspod.RecordListItem, } for _, v := range response.Response.RecordList { if *v.Name == name && *v.Value == value { + if line != "" && *v.Line != line { + continue + } logger.Default().Info("record found", zap.Any("record", v)) return v, nil } } return nil, errors.New("record not found") - +} + +func (w *watch) enable(ctx context.Context, domain string, recordId uint64) error { + // 开启记录 + request := dnspod.NewModifyRecordStatusRequest() + request.SetContext(ctx) + request.Domain = common.StringPtr(domain) + request.RecordId = common.Uint64Ptr(recordId) + request.Status = common.StringPtr("ENABLE") + _, err := w.client.ModifyRecordStatus(request) + if err != nil { + return err + } + return nil +} + +func (w *watch) disable(ctx context.Context, domain string, recordId uint64) error { + // 开启记录 + request := dnspod.NewModifyRecordStatusRequest() + request.SetContext(ctx) + request.Domain = common.StringPtr(domain) + request.RecordId = common.Uint64Ptr(recordId) + request.Status = common.StringPtr("DISABLE") + _, err := w.client.ModifyRecordStatus(request) + if err != nil { + return err + } + return nil } diff --git a/internal/watch/watch.go b/internal/watch/watch.go index 48328a0..f5a2038 100644 --- a/internal/watch/watch.go +++ b/internal/watch/watch.go @@ -12,10 +12,17 @@ import ( "go.uber.org/zap" ) +type LoadBalance struct { + Value string // 记录值 + Line string // 线路 +} + type CheckDomain struct { - Domain string // 域名 - Name string // 记录名 - Value []string // 记录值 + Domain string // 域名 + Name string // 记录名 + Value string // 记录值 + Line string // 线路 + LoadBalance *LoadBalance `yaml:"loadBalance"` // 负载均衡 } type Config struct { @@ -57,14 +64,15 @@ func (w *watch) Start(ctx context.Context, cfg *configs.Config) error { return err } for _, c := range w.config.CheckDomain { - for _, v := range c.Value { - r, err := w.queryRecord(c.Domain, c.Name, v) - if err != nil { - return err - } - record := newRecord(w, r, c.Domain, v) - go record.watch(ctx) + r, err := w.queryRecord(c.Domain, c.Name, c.Value, c.Line) + if err != nil { + return err } + record, err := newRecord(w, r, c) + if err != nil { + return err + } + go record.watch(ctx) } return nil }