Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes for home assistant service calling #23

Merged
merged 4 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 3 additions & 45 deletions cmd/api/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Fn struct {
Description string
MinArgs uint
MaxArgs uint
Syntax string
Call func(context.Context, *tablewriter.Writer, []string) error
}

Expand All @@ -39,51 +40,8 @@ func (c *Cmd) Get(name string) *Fn {
}

func (fn *Fn) CheckArgs(args []string) error {
// Check number of arguments
if fn.MinArgs != 0 && uint(len(args)) < fn.MinArgs {
return fmt.Errorf("not enough arguments for %q (expected >= %d)", fn.Name, fn.MinArgs)
}
if fn.MaxArgs != 0 && uint(len(args)) > fn.MaxArgs {
return fmt.Errorf("too many arguments for %q (expected <= %d)", fn.Name, fn.MaxArgs)
if (fn.MinArgs != 0 && uint(len(args)) < fn.MinArgs) || (fn.MaxArgs != 0 && uint(len(args)) > fn.MaxArgs) {
return fmt.Errorf("syntax error: %s %s", fn.Name, fn.Syntax)
}
return nil
}

/*
if fn == nil {
return nil, fmt.Errorf("unknown command %q", name)
}

return c.getFn(name), nil
// Get the command function
var fn *Fn
var nargs uint
var out []string
if len(args) == 0 {
fn = c.getFn("")
} else {
fn = c.getFn(args[0])
nargs = uint(len(args) - 1)
out = args[1:]
}
if fn == nil {
// No arguments and no default command
return nil, nil, nil
}

// Check number of arguments
name := fn.Name
if name == "" {
name = c.Name
}
if fn.MinArgs != 0 && nargs < fn.MinArgs {
return nil, nil, fmt.Errorf("not enough arguments for %q", name)
}
if fn.MaxArgs != 0 && nargs > fn.MaxArgs {
return nil, nil, fmt.Errorf("too many arguments for %q", name)
}

// Return the command
return fn, out, nil
}
*/
26 changes: 18 additions & 8 deletions cmd/api/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (flags *Flags) Parse(args []string) (*Fn, []string, error) {
// Parse command line
err := flags.FlagSet.Parse(args)

// If there is a version argument, print the version and exit
// Check for global commands
if flags.NArg() == 1 {
switch flags.Arg(0) {
case "version":
Expand All @@ -97,17 +97,27 @@ func (flags *Flags) Parse(args []string) (*Fn, []string, error) {
// If the name of the command is the same as the name of the application
flags.cmd = cmd
flags.root = cmd.Name
flags.fn = ""
flags.args = flags.Args()
if len(flags.Args()) > 0 {
flags.fn = flags.Arg(0)
if len(flags.Args()) > 1 {
flags.args = flags.Args()[1:]
}
}
} else if flags.NArg() > 0 {
if cmd := flags.getCommandSet(flags.Arg(0)); cmd != nil {
flags.cmd = cmd
flags.root = strings.Join([]string{flags.Name(), cmd.Name}, " ")
flags.fn = flags.Arg(1)
flags.args = flags.Args()[1:]
if len(flags.Args()) > 1 {
flags.args = flags.Args()[2:]
}
}
}

if flags.GetBool("debug") {
fmt.Fprintf(os.Stderr, "Function: %q Args: %q\n", flags.fn, flags.args)
}

// Print usage
if err != nil {
if err != flag.ErrHelp {
Expand All @@ -117,7 +127,7 @@ func (flags *Flags) Parse(args []string) (*Fn, []string, error) {
}
return nil, nil, err
} else if flags.cmd == nil {
fmt.Fprintln(os.Stderr, "Unknown command, try -help")
fmt.Fprintf(os.Stderr, "Unknown command, try \"%s -help\"\n", flags.Name())
return nil, nil, ErrHelp
}

Expand All @@ -140,7 +150,7 @@ func (flags *Flags) Parse(args []string) (*Fn, []string, error) {
// Set the function to call
fn := flags.cmd.Get(flags.fn)
if fn == nil {
fmt.Fprintf(os.Stderr, "Unknown command %q, try -help\n", flags.fn)
fmt.Fprintf(os.Stderr, "Unknown command, try \"%s -help\"\n", flags.Name())
return nil, nil, ErrHelp
}

Expand All @@ -151,7 +161,7 @@ func (flags *Flags) Parse(args []string) (*Fn, []string, error) {
}

// Return success
return fn, args, nil
return fn, flags.args, nil
}

// Get returns the value of a flag, and returns true if the flag exists
Expand Down Expand Up @@ -252,7 +262,7 @@ func (flags *Flags) PrintCommandUsage(cmd *Cmd) {
// Help for command sets
fmt.Fprintln(w, "Commands:")
for _, fn := range cmd.Fn {
fmt.Fprintln(w, " ", flags.root, fn.Name)
fmt.Fprintln(w, " ", flags.root, fn.Name, fn.Syntax)
fmt.Fprintln(w, " ", fn.Description)
fmt.Fprintln(w, "")
}
Expand Down
117 changes: 79 additions & 38 deletions cmd/api/homeassistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package main

import (
"context"
"fmt"
"slices"
"strings"
"time"

Expand All @@ -16,16 +14,19 @@ import (
// TYPES

type haEntity struct {
Id string `json:"entity_id"`
Id string `json:"entity_id,width:40"`
Name string `json:"name,omitempty"`
Class string `json:"class,omitempty"`
Domain string `json:"domain,omitempty"`
State string `json:"state,omitempty"`
Attributes map[string]interface{} `json:"attributes,omitempty,wrap"`
UpdatedAt time.Time `json:"last_updated,omitempty"`
UpdatedAt time.Time `json:"last_updated,omitempty,width:34"`
ChangedAt time.Time `json:"last_changed,omitempty,width:34"`
}

type haClass struct {
Class string `json:"class,omitempty"`
type haDomain struct {
Name string `json:"domain"`
Services string `json:"services,omitempty"`
}

///////////////////////////////////////////////////////////////////////////////
Expand All @@ -49,8 +50,10 @@ func haRegister(flags *Flags) {
Description: "Information from home assistant",
Parse: haParse,
Fn: []Fn{
{Name: "classes", Call: haClasses, Description: "Return entity classes"},
{Name: "states", Call: haStates, Description: "Return entity states"},
{Name: "domains", Call: haDomains, Description: "Enumerate entity domains"},
{Name: "states", Call: haStates, Description: "Show current entity states", MaxArgs: 1, Syntax: "(<name>)"},
{Name: "services", Call: haServices, Description: "Show services for an entity", MinArgs: 1, MaxArgs: 1, Syntax: "<entity>"},
{Name: "call", Call: haCall, Description: "Call a service for an entity", MinArgs: 2, MaxArgs: 2, Syntax: "<call> <entity>"},
},
})
}
Expand All @@ -71,14 +74,25 @@ func haParse(flags *Flags, opts ...client.ClientOpt) error {
// METHODS

func haStates(_ context.Context, w *tablewriter.Writer, args []string) error {
if states, err := haGetStates(args); err != nil {
var result []haEntity
states, err := haGetStates(nil)
if err != nil {
return err
} else {
return w.Write(states)
}

for _, state := range states {
if len(args) == 1 {
if !haMatchString(args[0], state.Name, state.Id) {
continue
}

}
result = append(result, state)
}
return w.Write(result)
}

func haClasses(_ context.Context, w *tablewriter.Writer, args []string) error {
func haDomains(_ context.Context, w *tablewriter.Writer, args []string) error {
states, err := haGetStates(nil)
if err != nil {
return err
Expand All @@ -89,17 +103,52 @@ func haClasses(_ context.Context, w *tablewriter.Writer, args []string) error {
classes[state.Class] = true
}

result := []haClass{}
result := []haDomain{}
for c := range classes {
result = append(result, haClass{Class: c})
result = append(result, haDomain{
Name: c,
})
}
return w.Write(result)
}

func haServices(_ context.Context, w *tablewriter.Writer, args []string) error {
service, err := haClient.State(args[0])
if err != nil {
return err
}
services, err := haClient.Services(service.Domain())
if err != nil {
return err
}
return w.Write(services)
}

func haCall(_ context.Context, w *tablewriter.Writer, args []string) error {
service := args[0]
entity := args[1]
states, err := haClient.Call(service, entity)
if err != nil {
return err
}
return w.Write(states)
}

///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

func haGetStates(classes []string) ([]haEntity, error) {
func haMatchString(q string, values ...string) bool {
q = strings.ToLower(q)
for _, v := range values {
v = strings.ToLower(v)
if strings.Contains(v, q) {
return true
}
}
return false
}

func haGetStates(domains []string) ([]haEntity, error) {
var result []haEntity

// Get states from the remote service
Expand All @@ -112,37 +161,29 @@ func haGetStates(classes []string) ([]haEntity, error) {
for _, state := range states {
entity := haEntity{
Id: state.Entity,
State: state.State,
Name: state.Name(),
Domain: state.Domain(),
Class: state.Class(),
State: state.Value(),
Attributes: state.Attributes,
UpdatedAt: state.LastChanged,
UpdatedAt: state.LastUpdated,
ChangedAt: state.LastChanged,
}

// Ignore entities without state
if entity.State == "" || entity.State == "unknown" || entity.State == "unavailable" {
// Ignore any fields where the state is empty
if entity.State == "" {
continue
}

// Set entity type and name from entity id
parts := strings.SplitN(entity.Id, ".", 2)
if len(parts) >= 2 {
entity.Class = strings.ToLower(parts[0])
entity.Name = parts[1]
}

// Set entity type from device class
if t, exists := state.Attributes["device_class"]; exists {
entity.Class = fmt.Sprint(t)
// Add unit of measurement
if unit := state.UnitOfMeasurement(); unit != "" {
entity.State += " " + unit
}

// Filter classes
if len(classes) > 0 && !slices.Contains(classes, entity.Class) {
continue
}

// Set entity name from attributes
if name, exists := state.Attributes["friendly_name"]; exists {
entity.Name = fmt.Sprint(name)
}
// Filter domains
//if len(domains) > 0 && !slices.Contains(domains, entity.Domain) {
// continue
//}

// Append results
result = append(result, entity)
Expand Down
8 changes: 4 additions & 4 deletions pkg/homeassistant/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ func Test_client_001(t *testing.T) {
// ENVIRONMENT

func GetApiKey(t *testing.T) string {
key := os.Getenv("HA_API_KEY")
key := os.Getenv("HA_TOKEN")
if key == "" {
t.Skip("HA_API_KEY not set")
t.Skip("HA_TOKEN not set")
t.SkipNow()
}
return key
}

func GetEndPoint(t *testing.T) string {
key := os.Getenv("HA_API_URL")
key := os.Getenv("HA_ENDPOINT")
if key == "" {
t.Skip("HA_API_URL not set")
t.Skip("HA_ENDPOINT not set")
t.SkipNow()
}
return key
Expand Down
Loading
Loading