diff --git a/dnsconfig/toml_zones.go b/dnsconfig/toml_zones.go new file mode 100644 index 0000000..e3fda6f --- /dev/null +++ b/dnsconfig/toml_zones.go @@ -0,0 +1,81 @@ +package dnsconfig + +import ( + "fmt" + "os" + "strings" + + "github.com/BurntSushi/toml" +) + +type TomlVariables struct { + Name string `toml:"name"` + Type string `toml:"type"` + Value string `toml:"value"` +} + +type TomlRecord struct { + Name string `toml:"name"` + Type string `toml:"type"` + Value string `toml:"value"` + TTL int `toml:"ttl"` +} + +type TomlZone struct { + Domain string `toml:"domain"` + DefaultTTL int `toml:"default_ttl"` + TomlRecords []TomlRecord `toml:"record"` + TomlVariables []TomlVariables `toml:"variable"` +} + +func GetTomlZones() ([]TomlZone, error) { + var zones []TomlZone + + files, err := os.ReadDir("./zones/") + if err != nil { + return zones, err + } + + for _, file := range files { + name := file.Name() + path := fmt.Sprintf("./zones/%s", name) + + if !strings.HasSuffix(strings.ToLower(name), ".toml") { + continue + } + + var zone TomlZone + _, err = toml.DecodeFile(path, &zone) + if err != nil { + continue + } + + if len(zone.Domain) == 0 { + continue + } + + if zone.DefaultTTL <= 0 { + zone.DefaultTTL = 3600 + } + + zones = append(zones, zone) + } + + return zones, nil +} + +func CreateTomlZone(zone TomlZone) error { + if len(zone.Domain) == 0 { + return fmt.Errorf("zone name cannot be empty") + } + + name := strings.ReplaceAll(strings.ToLower(zone.Domain), ".", "_") + path := fmt.Sprintf("./zones/%s.toml", name) + + file, err := os.Create(path) + if err != nil { + return err + } + + return toml.NewEncoder(file).Encode(&zone) +} diff --git a/dnsconfig/zones.go b/dnsconfig/zones.go index 4a9699e..72b0984 100644 --- a/dnsconfig/zones.go +++ b/dnsconfig/zones.go @@ -1,55 +1,142 @@ package dnsconfig import ( - "fmt" - "os" "strings" - "github.com/BurntSushi/toml" + "git.readonly.ch/bouzoure/gestion-dns/helpers" ) +type Variable struct { + Name string + Type string + Value string +} + type Record struct { - Name string `toml:"name"` - Type string `toml:"type"` - Value string `toml:"value"` - Flat bool `toml:"flat"` - TTL int `toml:"ttl"` + Name string + Type string + Value string + TTL int } type Zone struct { - Domain string `toml:"domain"` - DefaultTTL int `toml:"default_ttl"` - Records []Record `toml:"record"` + Domain string + Records []Record } func GetZones() ([]Zone, error) { + log := helpers.GetLogger() var zones []Zone - files, err := os.ReadDir("./zones/") + tomlZones, err := GetTomlZones() if err != nil { return zones, err } - for _, file := range files { - name := file.Name() - path := fmt.Sprintf("./zones/%s", name) - - if !strings.HasSuffix(strings.ToLower(name), ".toml") { - continue - } - + for _, tomlZone := range tomlZones { var zone Zone - _, err = toml.DecodeFile(path, &zone) - if err != nil { - continue + zone.Domain = tomlZone.Domain + + var tmpVariablesTypes []Variable + for _, tomlVariable := range tomlZone.TomlVariables { + if strings.ToUpper(tomlVariable.Type) == "A+AAAA" { + tmpVariablesTypes = append(tmpVariablesTypes, Variable{ + Name: tomlVariable.Name, + Value: tomlVariable.Value, + Type: "A", + }) + + tmpVariablesTypes = append(tmpVariablesTypes, Variable{ + Name: tomlVariable.Name, + Value: tomlVariable.Value, + Type: "AAAA", + }) + } else { + tmpVariablesTypes = append(tmpVariablesTypes, Variable(tomlVariable)) + } } - if len(zone.Domain) == 0 { - continue + var variablesResolved []Variable + for _, tmpVariable := range tmpVariablesTypes { + if strings.HasPrefix(tmpVariable.Value, "@") { + dnsName := strings.Replace(tmpVariable.Value, "@", "", 1) + resolveResults := helpers.ResolveRecord(dnsName, tmpVariable.Type) + + for _, resolveResult := range resolveResults { + variablesResolved = append(variablesResolved, Variable{ + Name: tmpVariable.Name, + Value: resolveResult, + Type: tmpVariable.Type, + }) + } + } else { + variablesResolved = append(variablesResolved, tmpVariable) + } } - if zone.DefaultTTL <= 0 { - zone.DefaultTTL = 3600 + var tmpRecordsTypes []Record + for _, tomlRecord := range tomlZone.TomlRecords { + if strings.EqualFold(tomlRecord.Type, "A+AAAA") { + tmpRecordsTypes = append(tmpRecordsTypes, Record{ + Name: tomlRecord.Name, + Value: tomlRecord.Value, + TTL: tomlRecord.TTL, + Type: "A", + }) + + tmpRecordsTypes = append(tmpRecordsTypes, Record{ + Name: tomlRecord.Name, + Value: tomlRecord.Value, + TTL: tomlRecord.TTL, + Type: "AAAA", + }) + } else { + tmpRecordsTypes = append(tmpRecordsTypes, Record(tomlRecord)) + } + } + + var tmpRecordsResolved []Record + for _, tmpRecord := range tmpRecordsTypes { + if strings.HasPrefix(tmpRecord.Value, "@") { + dnsName := strings.Replace(tmpRecord.Value, "@", "", 1) + resolveResults := helpers.ResolveRecord(dnsName, tmpRecord.Type) + + for _, resolveResult := range resolveResults { + tmpRecordsResolved = append(tmpRecordsResolved, Record{ + Name: tmpRecord.Name, + Value: resolveResult, + TTL: tmpRecord.TTL, + Type: tmpRecord.Type, + }) + } + } else { + tmpRecordsResolved = append(tmpRecordsResolved, tmpRecord) + } + } + + for _, tmpRecord := range tmpRecordsResolved { + if tmpRecord.TTL <= 0 { + tmpRecord.TTL = tomlZone.DefaultTTL + } + + // TODO: Handle vars + if strings.HasPrefix(tmpRecord.Value, "$") { + varName := strings.Replace(tmpRecord.Value, "$", "", 1) + found := false + for _, variable := range variablesResolved { + if strings.EqualFold(varName, variable.Name) && strings.EqualFold(tmpRecord.Type, variable.Type) { + found = true + tmpRecord.Value = variable.Value + } + } + + if !found { + log.Error("Could not find a variable with matching type", "name", varName, "type", tmpRecord.Type) + continue + } + } + + zone.Records = append(zone.Records, tmpRecord) } zones = append(zones, zone) @@ -57,19 +144,3 @@ func GetZones() ([]Zone, error) { return zones, nil } - -func CreateZone(zone Zone) error { - if len(zone.Domain) == 0 { - return fmt.Errorf("zone name cannot be empty") - } - - name := strings.ReplaceAll(strings.ToLower(zone.Domain), ".", "_") - path := fmt.Sprintf("./zones/%s.toml", name) - - file, err := os.Create(path) - if err != nil { - return err - } - - return toml.NewEncoder(file).Encode(&zone) -} diff --git a/helpers/resolver.go b/helpers/resolver.go index 8e18015..70cd476 100644 --- a/helpers/resolver.go +++ b/helpers/resolver.go @@ -4,11 +4,13 @@ import ( "github.com/domainr/dnsr" ) -func ResolveRecord(recordName, recordType string) string { - r := dnsr.NewResolver(dnsr.WithCache(0)) +func ResolveRecord(recordName, recordType string) []string { + var records []string + + r := dnsr.NewResolver(dnsr.WithCache(1000), dnsr.WithTCPRetry()) for _, rr := range r.Resolve(recordName, recordType) { - return rr.Value + records = append(records, rr.Value) } - return "" + return records } diff --git a/main.go b/main.go index bf6720e..cec124b 100644 --- a/main.go +++ b/main.go @@ -68,22 +68,21 @@ func main() { log.Info("Creating DNS config file for existing zone", "zone", hZone.Name) - zone := dnsconfig.Zone{ + zone := dnsconfig.TomlZone{ Domain: strings.ToLower(hZone.Name), DefaultTTL: 3600, } for _, record := range hZone.Records { - zone.Records = append(zone.Records, dnsconfig.Record{ + zone.TomlRecords = append(zone.TomlRecords, dnsconfig.TomlRecord{ Name: record.Name, Type: record.Type, Value: record.Value, TTL: record.TTL, - Flat: false, }) } - err = dnsconfig.CreateZone(zone) + err = dnsconfig.CreateTomlZone(zone) if err != nil { log.Fatal(err) } @@ -109,31 +108,12 @@ func main() { log.Info("Calculating sync diff (step 1: create/update)") var keepTheseIds []string for _, zone := range zones { - if zone.DefaultTTL <= 0 { - zone.DefaultTTL = 3600 - } - for _, hZone := range hZones { if strings.EqualFold(zone.Domain, hZone.Name) { log.Info("Calculating sync diff for zone", "name", zone.Domain) var alreadyFoundIds []string for _, record := range zone.Records { - if record.Flat { - record.Value = helpers.ResolveRecord( - record.Value, record.Type, - ) - - if len(record.Value) == 0 { - log.Error("Could not flatten record, skipping", "record", record) - continue - } - } - - if record.TTL <= 0 { - record.TTL = zone.DefaultTTL - } - var id string for _, hRecord := range hZone.Records { if slices.Contains(alreadyFoundIds, hRecord.ID) {