Skip to content

Commit

Permalink
add function importer
Browse files Browse the repository at this point in the history
  • Loading branch information
kvaps committed May 7, 2024
1 parent ff82a29 commit be081b7
Show file tree
Hide file tree
Showing 3 changed files with 394 additions and 0 deletions.
197 changes: 197 additions & 0 deletions pkg/commands/reset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package talos

import (
"context"
"fmt"
"sort"
"strings"

"github.com/siderolabs/gen/maps"
"github.com/spf13/cobra"

"github.com/siderolabs/talos/cmd/talosctl/cmd/common"
"github.com/siderolabs/talos/cmd/talosctl/pkg/talos/action"
"github.com/siderolabs/talos/cmd/talosctl/pkg/talos/helpers"
machineapi "github.com/siderolabs/talos/pkg/machinery/api/machine"
"github.com/siderolabs/talos/pkg/machinery/client"
)

var wipeOptions = map[string]machineapi.ResetRequest_WipeMode{
wipeModeAll: machineapi.ResetRequest_ALL,
wipeModeSystemDisk: machineapi.ResetRequest_SYSTEM_DISK,
wipeModeUserDisks: machineapi.ResetRequest_USER_DISKS,
}

// WipeMode apply, patch, edit config update mode.
type WipeMode machineapi.ResetRequest_WipeMode

const (
wipeModeAll = "all"
wipeModeSystemDisk = "system-disk"
wipeModeUserDisks = "user-disks"
)

func (m WipeMode) String() string {
switch machineapi.ResetRequest_WipeMode(m) {
case machineapi.ResetRequest_ALL:
return wipeModeAll
case machineapi.ResetRequest_SYSTEM_DISK:
return wipeModeSystemDisk
case machineapi.ResetRequest_USER_DISKS:
return wipeModeUserDisks
}

return wipeModeAll
}

// Set implements Flag interface.
func (m *WipeMode) Set(value string) error {
mode, ok := wipeOptions[value]
if !ok {
return fmt.Errorf("possible options are: %s", m.Type())
}

*m = WipeMode(mode)

return nil
}

// Type implements Flag interface.
func (m *WipeMode) Type() string {
options := maps.Keys(wipeOptions)
sort.Strings(options)

return strings.Join(options, ", ")
}

var resetCmdFlags struct {
trackableActionCmdFlags
graceful bool
reboot bool
insecure bool
wipeMode WipeMode
userDisksToWipe []string
systemLabelsToWipe []string
}

// resetCmd represents the reset command.
var resetCmd = &cobra.Command{
Use: "reset",
Short: "Reset a node",
Long: ``,
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
if resetCmdFlags.debug {
resetCmdFlags.wait = true
}

resetRequest := buildResetRequest()

if resetCmdFlags.wait && resetCmdFlags.insecure {
return fmt.Errorf("cannot use --wait and --insecure together")
}

if !resetCmdFlags.wait {
resetNoWait := func(ctx context.Context, c *client.Client) error {
if err := helpers.ClientVersionCheck(ctx, c); err != nil {
return err
}

if err := c.ResetGeneric(ctx, resetRequest); err != nil {
return fmt.Errorf("error executing reset: %s", err)
}

return nil
}

if resetCmdFlags.insecure {
return WithClientMaintenance(nil, resetNoWait)
}

return WithClient(resetNoWait)
}

actionFn := func(ctx context.Context, c *client.Client) (string, error) {
return resetGetActorID(ctx, c, resetRequest)
}

var postCheckFn func(context.Context, *client.Client, string) error

if resetCmdFlags.reboot {
postCheckFn = func(ctx context.Context, c *client.Client, preActionBootID string) error {
err := WithClientMaintenance(nil,
func(ctx context.Context, cli *client.Client) error {
_, err := cli.Disks(ctx)

return err
})

// if we can get into maintenance mode, reset has succeeded
if err == nil {
return nil
}

// try to get the boot ID in the normal mode to see if the node has rebooted
return action.BootIDChangedPostCheckFn(ctx, c, preActionBootID)
}
}

common.SuppressErrors = true

return action.NewTracker(
&GlobalArgs,
action.StopAllServicesEventFn,
actionFn,
action.WithPostCheck(postCheckFn),
action.WithDebug(resetCmdFlags.debug),
action.WithTimeout(resetCmdFlags.timeout),
).Run()
},
}

func buildResetRequest() *machineapi.ResetRequest {
systemPartitionsToWipe := make([]*machineapi.ResetPartitionSpec, 0, len(resetCmdFlags.systemLabelsToWipe))

for _, label := range resetCmdFlags.systemLabelsToWipe {
systemPartitionsToWipe = append(systemPartitionsToWipe, &machineapi.ResetPartitionSpec{
Label: label,
Wipe: true,
})
}

return &machineapi.ResetRequest{
Graceful: resetCmdFlags.graceful,
Reboot: resetCmdFlags.reboot,
UserDisksToWipe: resetCmdFlags.userDisksToWipe,
Mode: machineapi.ResetRequest_WipeMode(resetCmdFlags.wipeMode),
SystemPartitionsToWipe: systemPartitionsToWipe,
}
}

func resetGetActorID(ctx context.Context, c *client.Client, req *machineapi.ResetRequest) (string, error) {
resp, err := c.ResetGenericWithResponse(ctx, req)
if err != nil {
return "", err
}

if len(resp.GetMessages()) == 0 {
return "", fmt.Errorf("no messages returned from action run")
}

return resp.GetMessages()[0].GetActorId(), nil
}

func init() {
resetCmd.Flags().BoolVar(&resetCmdFlags.graceful, "graceful", true, "if true, attempt to cordon/drain node and leave etcd (if applicable)")
resetCmd.Flags().BoolVar(&resetCmdFlags.reboot, "reboot", false, "if true, reboot the node after resetting instead of shutting down")
resetCmd.Flags().BoolVar(&resetCmdFlags.insecure, "insecure", false, "reset using the insecure (encrypted with no auth) maintenance service")
resetCmd.Flags().Var(&resetCmdFlags.wipeMode, "wipe-mode", "disk reset mode")
resetCmd.Flags().StringSliceVar(&resetCmdFlags.userDisksToWipe, "user-disks-to-wipe", nil, "if set, wipes defined devices in the list")
resetCmd.Flags().StringSliceVar(&resetCmdFlags.systemLabelsToWipe, "system-labels-to-wipe", nil, "if set, just wipe selected system disk partitions by label but keep other partitions intact")
resetCmdFlags.addTrackActionFlags(resetCmd)
addCommand(resetCmd)
}
Binary file added tools/fix_imported
Binary file not shown.
197 changes: 197 additions & 0 deletions tools/import_functions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package main

import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path/filepath"
"strings"
)

func changePackageName(node *ast.File, newPackageName string) {
node.Name = ast.NewIdent(newPackageName)
}

func addFieldToStructDecl(node *ast.File, varName string, fieldType, fieldName string) {
ast.Inspect(node, func(n ast.Node) bool {
decl, ok := n.(*ast.GenDecl)
if !ok || decl.Tok != token.VAR {
return true
}
for _, spec := range decl.Specs {
vs, ok := spec.(*ast.ValueSpec)
if !ok || len(vs.Names) != 1 || vs.Names[0].Name != varName {
continue
}
st, ok := vs.Type.(*ast.StructType)
if !ok {
continue
}
field := &ast.Field{
Names: []*ast.Ident{ast.NewIdent(fieldName)},
Type: ast.NewIdent(fieldType),
}
st.Fields.List = append(st.Fields.List, field)
return false
}
return true
})
}

func prependStmtToInit(node *ast.File, cmdName string) {
ast.Inspect(node, func(n ast.Node) bool {
fn, ok := n.(*ast.FuncDecl)
if ok && fn.Name.Name == "init" {
stmt := &ast.ExprStmt{
X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.SelectorExpr{
X: ast.NewIdent(cmdName + "Cmd"),
Sel: ast.NewIdent("Flags()"),
},
Sel: ast.NewIdent("StringSliceVarP"),
},
Args: []ast.Expr{
&ast.UnaryExpr{
Op: token.AND,
X: ast.NewIdent(cmdName + "CmdFlags.configFiles"),
},
ast.NewIdent(`"file"`),
ast.NewIdent(`"f"`),
ast.NewIdent("nil"),
ast.NewIdent(`"specify config files or patches in a YAML file (can specify multiple)"`),
},
},
}
fn.Body.List = append([]ast.Stmt{stmt}, fn.Body.List...)
return false
}
return true
})
}

func insertInitCode(node *ast.File, cmdName, initCode string) {
anonFuncCode := fmt.Sprintf(`func() { %s }`, initCode)

initCodeExpr, err := parser.ParseExpr(anonFuncCode)
if err != nil {
log.Fatalf("Failed to parse init code: %v", err)
}

ast.Inspect(node, func(n ast.Node) bool {
switch x := n.(type) {
case *ast.FuncDecl:
if x.Name.Name == "init" {
if x.Body != nil {
initFunc, ok := initCodeExpr.(*ast.FuncLit)
if !ok {
log.Fatalf("Failed to extract function body from init code expression")
}

x.Body.List = append(initFunc.Body.List, x.Body.List...)
}
}
}
return true
})
}

func processFile(filename string) {
fset := token.NewFileSet()
node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
if err != nil {
log.Fatalf("Failed to parse file: %v", err)
}

cmdName := strings.TrimSuffix(filepath.Base(filename), ".go")
changePackageName(node, "commands")
addFieldToStructDecl(node, cmdName+"CmdFlags", "[]string", "configFiles")

initCode := fmt.Sprintf(`%sCmd.Flags().StringSliceVarP(&%sCmdFlags.configFiles, "file", "f", nil, "specify config files or patches in a YAML file (can specify multiple)")
%sCmd.PreRunE = func(cmd *cobra.Command, args []string) error {
nodesFromArgs := len(
GlobalArgs.Nodes) > 0
endpointsFromArgs := len(GlobalArgs.Endpoints) >
0
for _, configFile := range %sCmdFlags.configFiles {
if err :=
processModelineAndUpdateGlobals(configFile,
nodesFromArgs,
endpointsFromArgs, false,
); err != nil {
return err
}
}
return nil
}
`, cmdName, cmdName, cmdName, cmdName)

insertInitCode(node, cmdName, initCode)

var buf bytes.Buffer
if err := format.Node(&buf, fset, node); err != nil {
log.Fatalf("Failed to format the AST: %v", err)
}

if err := ioutil.WriteFile(filename, buf.Bytes(), 0644); err != nil {
log.Fatalf("Failed to write the modified file: %v", err)
}

log.Printf("File %s updated successfully.", filename)
}

func main() {
talosVersion := flag.String("talos-version", "main", "the desired Talos version (branch or tag)")
flag.Parse()
url := fmt.Sprintf("https://github.com/siderolabs/talos/raw/%s/cmd/talosctl/cmd/talos/", *talosVersion)

args := flag.Args()
if len(args) == 0 {
fmt.Println("Please provide commands to import")
return
}

for _, cmd := range args {
srcName := cmd + ".go"
dstName := "pkg/commands/imported_" + srcName

err := downloadFile(srcName, dstName, url)
if err != nil {
log.Fatalf("Error downloading file: %v", err)
}

log.Printf("File %s succefully downloaded to %s", srcName, dstName)
processFile(dstName)
}
}

func downloadFile(srcName, dstName string, url string) error {
resp, err := http.Get(url + "/" + srcName)
if err != nil {
return err
}
defer resp.Body.Close()

file, err := os.Create(dstName)
if err != nil {
return err
}
defer file.Close()

_, err = io.Copy(file, resp.Body)
if err != nil {
return err
}

return nil
}

0 comments on commit be081b7

Please sign in to comment.