Skip to content

Commit

Permalink
AliasMangler
Browse files Browse the repository at this point in the history
Add a new mangler that allows us to have aliases for dials tags to
facilitate migrating from one name to another seamlessly.  If both the
old name and the new name are set, an error will be returned because the
expectation is that you're using one or the other exclusively.

Also, fix a bug in `ez` where an error thrown by a Source may cause the
process to hang.
  • Loading branch information
sergiosalvatore committed Aug 23, 2024
1 parent 402b482 commit 0ff99d6
Show file tree
Hide file tree
Showing 8 changed files with 374 additions and 35 deletions.
21 changes: 19 additions & 2 deletions common/tags.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
// Package common provides constants that are used among different dials sources
package common

// DialsTagName is the name of the dials tag.
const DialsTagName = "dials"
const (
// DialsTagName is the name of the dials tag.
DialsTagName = "dials"

// DialsEnvTagName is the name of the dialsenv tag.
DialsEnvTagName = "dialsenv"

// DialsFlagTagName is the name of the dialsflag tag.
DialsFlagTagName = "dialsflag"

// DialsPFlagTagName is the name of the dialspflag tag.
DialsPFlagTag = "dialspflag"

// DialsFlagAliasTag is the name of the dialsflagalias tag.
DialsPFlagShortTag = "dialspflagshort"

// HelpTextTag is the name of the struct tag for flag descriptions
DialsHelpTextTag = "dialsdesc"
)
23 changes: 12 additions & 11 deletions ez/ez.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,6 @@ func ConfigFileEnvFlagDecoderFactoryParams[T any, TP ConfigWithConfigPath[T]](ct
flagSrc = fset
}

// If file-watching is not enabled, we should shutdown the monitor
// goroutine when exiting this function.
// Usually `dials.Config` is smart enough not to start a monitor when
// there are no `Watcher` implementations in the source-list, but the
// `Blank` source uses `Watcher` for its core functionality, so we need
// to shutdown the blank source to actually clean up resources.
if !params.WatchConfigFile {
defer blank.Done(ctx)
}

dp := dials.Params[T]{
// Set the OnNewConfig callback. It'll be suppressed by the
// CallGlobalCallbacksAfterVerificationEnabled until just before we return.
Expand All @@ -199,6 +189,16 @@ func ConfigFileEnvFlagDecoderFactoryParams[T any, TP ConfigWithConfigPath[T]](ct
return nil, err
}

// If file-watching is not enabled, we should shutdown the monitor
// goroutine when exiting this function.
// Usually `dials.Config` is smart enough not to start a monitor when
// there are no `Watcher` implementations in the source-list, but the
// `Blank` source uses `Watcher` for its core functionality, so we need
// to shutdown the blank source to actually clean up resources.
if !params.WatchConfigFile {
defer blank.Done(ctx)
}

basecfg := d.View()
cfgPath, filepathSet := (TP)(basecfg).ConfigPath()
if !filepathSet {
Expand All @@ -219,7 +219,8 @@ func ConfigFileEnvFlagDecoderFactoryParams[T any, TP ConfigWithConfigPath[T]](ct
return nil, fmt.Errorf("decoderFactory provided a nil decoder for path: %s", cfgPath)
}

manglers := make([]transform.Mangler, 0, 2)
manglers := make([]transform.Mangler, 0, 3)
manglers[0] = &transform.AliasMangler{}

if params.FileFieldNameEncoder != nil {
tagDecoder := params.DialsTagNameDecoder
Expand Down
29 changes: 28 additions & 1 deletion ez/ez_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
type config struct {
// Path will contain the path to the config file and will be set by
// environment variable
Path string `dials:"CONFIGPATH"`
Path string `dials:"CONFIGPATH" dialsalias:"ALTCONFIGPATH"`
Val1 int `dials:"Val1"`
Val2 string `dials:"Val2"`
Set map[string]struct{} `dials:"Set"`
Expand Down Expand Up @@ -59,6 +59,33 @@ func TestYAMLConfigEnvFlagWithValidConfig(t *testing.T) {
assert.EqualValues(t, expectedConfig, *populatedConf)
}

func TestYAMLConfigEnvFlagWithValidConfigAndAlias(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

envErr := os.Setenv("ALTCONFIGPATH", "../testhelper/testconfig.yaml")
require.NoError(t, envErr)
defer os.Unsetenv("ALTCONFIGPATH")

c := &config{}
view, dialsErr := YAMLConfigEnvFlag(ctx, c, Params[config]{})
require.NoError(t, dialsErr)

// Val1 and Val2 come from the config file and Path will be populated from env variable
expectedConfig := config{
Path: "../testhelper/testconfig.yaml",
Val1: 456,
Val2: "hello-world",
Set: map[string]struct{}{
"Keith": {},
"Gary": {},
"Jack": {},
},
}
populatedConf := view.View()
assert.EqualValues(t, expectedConfig, *populatedConf)
}

type beatlesConfig struct {
YAMLPath string
BeatlesMembers map[string]string
Expand Down
12 changes: 6 additions & 6 deletions sources/env/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import (
"github.com/vimeo/dials/transform"
)

const envTagName = "dialsenv"

// Source implements the dials.Source interface to set configuration from
// environment variables.
type Source struct {
Expand All @@ -36,10 +34,12 @@ func (e *Source) Value(_ context.Context, t *dials.Type) (reflect.Value, error)
// reformat the tags so they are SCREAMING_SNAKE_CASE
reformatTagMangler := tagformat.NewTagReformattingMangler(common.DialsTagName, caseconversion.DecodeGoTags, caseconversion.EncodeUpperSnakeCase)
// copy tags from "dials" to "dialsenv" tag
tagCopyingMangler := &tagformat.TagCopyingMangler{SrcTag: common.DialsTagName, NewTag: envTagName}
tagCopyingMangler := &tagformat.TagCopyingMangler{SrcTag: common.DialsTagName, NewTag: common.DialsEnvTagName}
// convert all the fields in the flattened struct to string type so the environment variables can be set
stringCastingMangler := &transform.StringCastingMangler{}
tfmr := transform.NewTransformer(t.Type(), flattenMangler, reformatTagMangler, tagCopyingMangler, stringCastingMangler)
// allow aliasing to migrate from one name to another
aliasMangler := &transform.AliasMangler{}
tfmr := transform.NewTransformer(t.Type(), aliasMangler, flattenMangler, reformatTagMangler, tagCopyingMangler, stringCastingMangler)

val, err := tfmr.Translate()
if err != nil {
Expand All @@ -49,11 +49,11 @@ func (e *Source) Value(_ context.Context, t *dials.Type) (reflect.Value, error)
valType := val.Type()
for i := 0; i < val.NumField(); i++ {
sf := valType.Field(i)
envTagVal := sf.Tag.Get(envTagName)
envTagVal := sf.Tag.Get(common.DialsEnvTagName)
if envTagVal == "" {
// dialsenv tag should be populated because dials tag is populated
// after flatten mangler and we copy from dials to dialsenv tag
panic(fmt.Errorf("empty %s tag for field name %s", envTagName, sf.Name))
panic(fmt.Errorf("empty %s tag for field name %s", common.DialsEnvTagName, sf.Name))
}

if e.Prefix != "" {
Expand Down
8 changes: 3 additions & 5 deletions sources/flag/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ var (
_ dials.Source = (*Set)(nil)
)

const dialsFlagTag = "dialsflag"

// NameConfig defines the parameters for separating components of a flag-name
type NameConfig struct {
// FieldNameEncodeCasing is for the field names used by the flatten mangler
Expand Down Expand Up @@ -204,7 +202,7 @@ func (s *Set) parse() error {

func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error {
fm := transform.NewFlattenMangler(common.DialsTagName, s.NameCfg.FieldNameEncodeCasing, s.NameCfg.TagEncodeCasing)
tfmr := transform.NewTransformer(ptyp, fm)
tfmr := transform.NewTransformer(ptyp, &transform.AliasMangler{}, fm)
val, TrnslErr := tfmr.Translate()
if TrnslErr != nil {
return TrnslErr
Expand Down Expand Up @@ -241,7 +239,7 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error {
// If the field's dialsflag tag is a hyphen (ex: `dialsflag:"-"`),
// don't register the flag. Currently nested fields with "-" tag will
// still be registered
if dft, ok := sf.Tag.Lookup(dialsFlagTag); ok && (dft == "-") {
if dft, ok := sf.Tag.Lookup(common.DialsFlagTagName); ok && (dft == "-") {
continue
}

Expand Down Expand Up @@ -506,7 +504,7 @@ func willOverflow(val, target reflect.Value) bool {
// decoded field name and converting it into kebab case
func (s *Set) mkname(sf reflect.StructField) string {
// use the name from the dialsflag tag for the flag name
if name, ok := sf.Tag.Lookup(dialsFlagTag); ok {
if name, ok := sf.Tag.Lookup(common.DialsFlagTagName); ok {
return name
}
// check if the dials tag is populated (it should be once it goes through
Expand Down
16 changes: 6 additions & 10 deletions sources/pflag/pflag.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,9 @@ var (
)

const (
dialsPFlagTag = "dialspflag"
dialsPFlagShortTag = "dialspflagshort"
// HelpTextTag is the name of the struct tag for flag descriptions
HelpTextTag = "dialsdesc"
// DefaultFlagHelpText is the default help-text for fields with an
// unset dialsdesc tag.
DefaultFlagHelpText = "unset description (`" + HelpTextTag + "` struct tag)"
DefaultFlagHelpText = "unset description (`" + common.DialsHelpTextTag + "` struct tag)"
)

// NameConfig defines the parameters for separating components of a flag-name
Expand Down Expand Up @@ -214,7 +210,7 @@ func (s *Set) parse() error {

func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error {
fm := transform.NewFlattenMangler(common.DialsTagName, s.NameCfg.FieldNameEncodeCasing, s.NameCfg.TagEncodeCasing)
tfmr := transform.NewTransformer(ptyp, fm)
tfmr := transform.NewTransformer(ptyp, &transform.AliasMangler{}, fm)
val, TrnslErr := tfmr.Translate()
if TrnslErr != nil {
return TrnslErr
Expand All @@ -235,7 +231,7 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error {
for i := 0; i < t.NumField(); i++ {
sf := t.Field(i)
help := DefaultFlagHelpText
if x, ok := sf.Tag.Lookup(HelpTextTag); ok {
if x, ok := sf.Tag.Lookup(common.DialsHelpTextTag); ok {
help = x
}

Expand All @@ -251,7 +247,7 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error {
// If the field's dialspflag tag is a hyphen (ex: `dialspflag:"-"`),
// don't register the flag. Currently nested fields with "-" tag will
// still be registered
if dpt, ok := sf.Tag.Lookup(dialsPFlagTag); ok && (dpt == "-") {
if dpt, ok := sf.Tag.Lookup(common.DialsPFlagTag); ok && (dpt == "-") {
continue
}

Expand All @@ -267,7 +263,7 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error {

// get the concrete value of the field from the template
fieldVal := transform.GetField(sf, tmpl)
shorthand, _ := sf.Tag.Lookup(dialsPFlagShortTag)
shorthand, _ := sf.Tag.Lookup(common.DialsPFlagShortTag)
var f interface{}

switch {
Expand Down Expand Up @@ -516,7 +512,7 @@ func stripTypePtr(t reflect.Type) reflect.Type {
// decoded field name and converting it into kebab case
func (s *Set) mkname(sf reflect.StructField) string {
// use the name from the dialspflag tag for the flag name
if name, ok := sf.Tag.Lookup(dialsPFlagTag); ok {
if name, ok := sf.Tag.Lookup(common.DialsPFlagTag); ok {
return name
}
// check if the dials tag is populated (it should be once it goes through
Expand Down
145 changes: 145 additions & 0 deletions transform/alias_mangler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package transform

import (
"fmt"
"reflect"

"github.com/fatih/structtag"
"github.com/vimeo/dials/common"
)

const (
dialsAliasTagSuffix = "alias"
aliasFieldSuffix = "_alias9wr876rw3" // a random string to append to the alias field to avoid collisions
)

// the list of tags that we should search for aliases
var aliasSourceTags = []string{
common.DialsFlagTagName,
common.DialsEnvTagName,
common.DialsFlagTagName,
common.DialsPFlagTag,
common.DialsPFlagShortTag,
}

// AliasMangler manages aliases for dials, dialsenv, dialsflag, and dialspflag
// struct tags to make it possible to migrate from one name to another
// conveniently.
type AliasMangler struct{}

// Mangle implements the Mangler interface. If an alias tag is defined, the
// struct field will be copied with the non-aliased tag set to the alias's
// value.
func (a AliasMangler) Mangle(sf reflect.StructField) ([]reflect.StructField, error) {
originalVals := map[string]string{}
aliasVals := map[string]string{}

sfTags, parseErr := structtag.Parse(string(sf.Tag))
if parseErr != nil {
return nil, fmt.Errorf("error parsing source tags %w", parseErr)
}

anyAliasFound := false
for _, tag := range aliasSourceTags {
if originalVal, getErr := sfTags.Get(tag); getErr == nil {
originalVals[tag] = originalVal.Name
}

if aliasVal, getErr := sfTags.Get(tag + dialsAliasTagSuffix); getErr == nil {
aliasVals[tag] = aliasVal.Name
anyAliasFound = true

// remove the alias tag from the definition
sfTags.Delete(tag + dialsAliasTagSuffix)
}
}

if !anyAliasFound {
// we didn't find any aliases so just get out early
return []reflect.StructField{sf}, nil
}

aliasField := sf
aliasField.Name += aliasFieldSuffix

// now that we've copied it, reset the struct tags on the source field to
// not include the alias tags
sf.Tag = reflect.StructTag(sfTags.String())

tags, parseErr := structtag.Parse(string(aliasField.Tag))
if parseErr != nil {
return nil, fmt.Errorf("error parsing struct tags: %w", parseErr)
}

for _, tag := range aliasSourceTags {
// remove the alias tag so it's not left on the copied StructField
tags.Delete(tag + dialsAliasTagSuffix)

if aliasVals[tag] == "" {
// if the particular flag isn't set at all just move on...
continue
}

newDialsTag := &structtag.Tag{
Key: tag,
Name: aliasVals[tag],
}

if setErr := tags.Set(newDialsTag); setErr != nil {
return nil, fmt.Errorf("error setting new value for dials tag: %w", setErr)
}

// update dialsdesc if there is one
if desc, getErr := tags.Get("dialsdesc"); getErr == nil {
newDesc := &structtag.Tag{
Key: "dialsdesc",
Name: desc.Name + " (alias of `" + originalVals[tag] + "`)",
}
if setErr := tags.Set(newDesc); setErr != nil {
return nil, fmt.Errorf("error setting amended dialsdesc for tag %q: %w", tag, setErr)
}
}
}

// set the new flags on the alias field
aliasField.Tag = reflect.StructTag(tags.String())

return []reflect.StructField{sf, aliasField}, nil
}

// Unmangle implements the Mangler interface and unwinds the alias copying
// operation. Note that if both the source and alias are both set in the
// configuration, an error will be returned.
func (a AliasMangler) Unmangle(sf reflect.StructField, fvs []FieldValueTuple) (reflect.Value, error) {
switch len(fvs) {
case 1:
// if there's only one tuple that means there was no alias, so just
// return...
return fvs[0].Value, nil
case 2:
// two means there's an alias so we should continue on...
default:
return reflect.Value{}, fmt.Errorf("expected 1 or 2 tuples, got %d", len(fvs))
}

if !fvs[0].Value.IsNil() && !fvs[1].Value.IsNil() {
return reflect.Value{}, fmt.Errorf("both alias and original set for field %q", sf.Name)
}

// return the first one that isn't nil
for _, fv := range fvs {
if !fv.Value.IsNil() {
return fv.Value, nil
}
}

// if we made it this far, they were both nil, which is fine -- just return
// one of them.
return fvs[0].Value, nil
}

// ShouldRecurse is called after Mangle for each field so nested struct
// fields get iterated over after any transformation done by Mangle().
func (a AliasMangler) ShouldRecurse(_ reflect.StructField) bool {
return true
}
Loading

0 comments on commit 0ff99d6

Please sign in to comment.