diff --git a/internal/slurm/slurm_shares.go b/internal/slurm/slurm_shares.go new file mode 100644 index 0000000..81638c3 --- /dev/null +++ b/internal/slurm/slurm_shares.go @@ -0,0 +1,44 @@ +package slurm + +import ( + "context" + "net/url" +) + +// GetSharesOptions specifies optional parameters for GetShares. +type GetSharesOptions struct { + Accounts *string `url:"accounts,omitempty"` + Users *string `url:"users,omitempty"` +} + +// GetShares retrieves fairshare shares information. +func (s *SharesService) GetShares(ctx context.Context, opts *GetSharesOptions) (*OpenapiSharesResp, *Response, error) { + path := "slurm/v0.0.40/shares" + req, err := s.client.NewRequest("GET", path, nil) + if err != nil { + return nil, nil, err + } + + if opts != nil { + u, parseErr := url.Parse(req.URL.String()) + if parseErr != nil { + return nil, nil, parseErr + } + q := u.Query() + if opts.Accounts != nil { + q.Set("accounts", *opts.Accounts) + } + if opts.Users != nil { + q.Set("users", *opts.Users) + } + u.RawQuery = q.Encode() + req.URL = u + } + + var result OpenapiSharesResp + resp, err := s.client.Do(ctx, req, &result) + if err != nil { + return nil, resp, err + } + return &result, resp, nil +} diff --git a/internal/slurm/slurm_shares_test.go b/internal/slurm/slurm_shares_test.go new file mode 100644 index 0000000..a418d65 --- /dev/null +++ b/internal/slurm/slurm_shares_test.go @@ -0,0 +1,84 @@ +package slurm + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestSharesService_GetShares(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/slurm/v0.0.40/shares", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + fmt.Fprint(w, `{"shares": {"shares": [], "total_shares": 100}}`) + }) + server := httptest.NewServer(mux) + defer server.Close() + + client, _ := NewClient(server.URL, nil) + resp, _, err := client.Shares.GetShares(context.Background(), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if resp.Shares == nil { + t.Fatal("expected non-nil shares") + } + if resp.Shares.TotalShares == nil || *resp.Shares.TotalShares != 100 { + t.Errorf("expected total_shares=100, got %v", resp.Shares.TotalShares) + } +} + +func TestSharesService_GetShares_WithOptions(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/slurm/v0.0.40/shares", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + q := r.URL.Query() + if q.Get("accounts") != "acc1" { + t.Errorf("expected accounts=acc1, got %s", q.Get("accounts")) + } + if q.Get("users") != "user1" { + t.Errorf("expected users=user1, got %s", q.Get("users")) + } + fmt.Fprint(w, `{"shares": {"shares": [], "total_shares": 50}}`) + }) + server := httptest.NewServer(mux) + defer server.Close() + + client, _ := NewClient(server.URL, nil) + opts := &GetSharesOptions{ + Accounts: Ptr("acc1"), + Users: Ptr("user1"), + } + resp, _, err := client.Shares.GetShares(context.Background(), opts) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } +} + +func TestSharesService_GetShares_Error(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/slurm/v0.0.40/shares", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, `{"errors": [{"error": "internal error"}]}`) + }) + server := httptest.NewServer(mux) + defer server.Close() + + client, _ := NewClient(server.URL, nil) + _, _, err := client.Shares.GetShares(context.Background(), nil) + if err == nil { + t.Fatal("expected error for 500 response") + } + if !strings.Contains(err.Error(), "500") { + t.Errorf("expected error to contain 500, got %v", err) + } +} diff --git a/internal/slurm/types_shares.go b/internal/slurm/types_shares.go new file mode 100644 index 0000000..65518d3 --- /dev/null +++ b/internal/slurm/types_shares.go @@ -0,0 +1,70 @@ +package slurm + +// --------------------------------------------------------------------------- +// Shares — types for the Slurm fairshare/shares API (v0.0.40) +// --------------------------------------------------------------------------- + +// SharesFloat128Tres represents a float128 TRES entry (v0.0.40_shares_float128_tres). +type SharesFloat128Tres struct { + Name *string `json:"name,omitempty"` + Value *float64 `json:"value,omitempty"` +} + +// SharesFloat128TresList is an array of SharesFloat128Tres (v0.0.40_shares_float128_tres_list). +type SharesFloat128TresList []SharesFloat128Tres + +// SharesUint64Tres represents a uint64 TRES entry (v0.0.40_shares_uint64_tres). +type SharesUint64Tres struct { + Name *string `json:"name,omitempty"` + Value *Uint64NoVal `json:"value,omitempty"` +} + +// SharesUint64TresList is an array of SharesUint64Tres (v0.0.40_shares_uint64_tres_list). +type SharesUint64TresList []SharesUint64Tres + +// AssocSharesObjTres holds TRES usage/limit breakdowns within an association share. +type AssocSharesObjTres struct { + RunSeconds *SharesUint64TresList `json:"run_seconds,omitempty"` + GroupMinutes *SharesUint64TresList `json:"group_minutes,omitempty"` + Usage *SharesFloat128TresList `json:"usage,omitempty"` +} + +// AssocSharesObjFairshare holds fairshare factor and level. +type AssocSharesObjFairshare struct { + Factor *float64 `json:"factor,omitempty"` + Level *float64 `json:"level,omitempty"` +} + +// AssocSharesObjWrap represents an association shares object (v0.0.40_assoc_shares_obj_wrap). +type AssocSharesObjWrap struct { + ID *int32 `json:"id,omitempty"` + Cluster *string `json:"cluster,omitempty"` + Name *string `json:"name,omitempty"` + Parent *string `json:"parent,omitempty"` + Partition *string `json:"partition,omitempty"` + SharesNormalized *Float64NoVal `json:"shares_normalized,omitempty"` + Shares *Uint32NoVal `json:"shares,omitempty"` + Tres *AssocSharesObjTres `json:"tres,omitempty"` + EffectiveUsage *float64 `json:"effective_usage,omitempty"` + UsageNormalized *Float64NoVal `json:"usage_normalized,omitempty"` + Usage *int64 `json:"usage,omitempty"` + Fairshare *AssocSharesObjFairshare `json:"fairshare,omitempty"` + Type []string `json:"type,omitempty"` +} + +// AssocSharesObjList is an array of AssocSharesObjWrap (v0.0.40_assoc_shares_obj_list). +type AssocSharesObjList []AssocSharesObjWrap + +// SharesRespMsg is the shares response message (v0.0.40_shares_resp_msg). +type SharesRespMsg struct { + Shares *AssocSharesObjList `json:"shares,omitempty"` + TotalShares *int64 `json:"total_shares,omitempty"` +} + +// OpenapiSharesResp is the top-level shares API response (v0.0.40_openapi_shares_resp). +type OpenapiSharesResp struct { + Shares *SharesRespMsg `json:"shares,omitempty"` + Meta *OpenapiMeta `json:"meta,omitempty"` + Errors OpenapiErrors `json:"errors,omitempty"` + Warnings OpenapiWarnings `json:"warnings,omitempty"` +} diff --git a/internal/slurm/types_shares_test.go b/internal/slurm/types_shares_test.go new file mode 100644 index 0000000..cff3257 --- /dev/null +++ b/internal/slurm/types_shares_test.go @@ -0,0 +1,271 @@ +package slurm + +import ( + "encoding/json" + "testing" +) + +func TestSharesFloat128TresRoundTrip(t *testing.T) { + orig := SharesFloat128Tres{ + Name: Ptr("cpu"), + Value: Ptr(1234.56), + } + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var decoded SharesFloat128Tres + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if *decoded.Name != *orig.Name { + t.Errorf("Name: got %s, want %s", *decoded.Name, *orig.Name) + } + if *decoded.Value != *orig.Value { + t.Errorf("Value: got %f, want %f", *decoded.Value, *orig.Value) + } +} + +func TestSharesUint64TresRoundTrip(t *testing.T) { + orig := SharesUint64Tres{ + Name: Ptr("mem"), + Value: &Uint64NoVal{ + Set: Ptr(true), + Number: Ptr(int64(4096)), + }, + } + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var decoded SharesUint64Tres + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if *decoded.Name != *orig.Name { + t.Errorf("Name: got %s, want %s", *decoded.Name, *orig.Name) + } + if *decoded.Value.Number != *orig.Value.Number { + t.Errorf("Value.Number: got %d, want %d", *decoded.Value.Number, *orig.Value.Number) + } +} + +func TestSharesRespMsgRoundTrip(t *testing.T) { + orig := SharesRespMsg{ + Shares: &AssocSharesObjList{ + { + ID: Ptr(int32(1)), + Cluster: Ptr("cluster1"), + Name: Ptr("root"), + Parent: Ptr(""), + Shares: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(100))}, + Type: []string{"ASSOCIATION"}, + Fairshare: &AssocSharesObjFairshare{ + Factor: Ptr(0.5), + Level: Ptr(0.5), + }, + Tres: &AssocSharesObjTres{ + Usage: &SharesFloat128TresList{ + {Name: Ptr("cpu"), Value: Ptr(100000.0)}, + }, + RunSeconds: &SharesUint64TresList{ + {Name: Ptr("cpu"), Value: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(3600))}}, + }, + }, + SharesNormalized: &Float64NoVal{Set: Ptr(true), Number: Ptr(1.0)}, + EffectiveUsage: Ptr(0.75), + UsageNormalized: &Float64NoVal{Set: Ptr(true), Number: Ptr(0.75)}, + Usage: Ptr(int64(50000)), + }, + }, + TotalShares: Ptr(int64(100)), + } + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var decoded SharesRespMsg + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if *decoded.TotalShares != *orig.TotalShares { + t.Errorf("TotalShares: got %d, want %d", *decoded.TotalShares, *orig.TotalShares) + } + if len(*decoded.Shares) != len(*orig.Shares) { + t.Fatalf("Shares length: got %d, want %d", len(*decoded.Shares), len(*orig.Shares)) + } + got := (*decoded.Shares)[0] + want := (*orig.Shares)[0] + if *got.Name != *want.Name { + t.Errorf("Shares[0].Name: got %s, want %s", *got.Name, *want.Name) + } + if *got.Cluster != *want.Cluster { + t.Errorf("Shares[0].Cluster: got %s, want %s", *got.Cluster, *want.Cluster) + } + if *got.Shares.Number != *want.Shares.Number { + t.Errorf("Shares[0].Shares.Number: got %d, want %d", *got.Shares.Number, *want.Shares.Number) + } + if len(got.Type) != 1 || got.Type[0] != "ASSOCIATION" { + t.Errorf("Shares[0].Type: got %v", got.Type) + } + if got.Fairshare == nil || *got.Fairshare.Factor != 0.5 { + t.Errorf("Shares[0].Fairshare.Factor: got %v", got.Fairshare) + } + if got.Tres == nil || got.Tres.Usage == nil || len(*got.Tres.Usage) != 1 { + t.Fatal("Shares[0].Tres.Usage missing or wrong length") + } + if *(*got.Tres.Usage)[0].Name != "cpu" { + t.Errorf("Shares[0].Tres.Usage[0].Name: got %s", *(*got.Tres.Usage)[0].Name) + } +} + +func TestAssocSharesObjWrapRoundTrip(t *testing.T) { + orig := AssocSharesObjWrap{ + ID: Ptr(int32(42)), + Cluster: Ptr("test-cluster"), + Name: Ptr("alice"), + Parent: Ptr("team-a"), + Partition: Ptr("gpu"), + SharesNormalized: &Float64NoVal{Set: Ptr(true), Number: Ptr(0.25)}, + Shares: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(50))}, + Tres: &AssocSharesObjTres{ + RunSeconds: &SharesUint64TresList{ + {Name: Ptr("cpu"), Value: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(7200))}}, + }, + GroupMinutes: &SharesUint64TresList{ + {Name: Ptr("cpu"), Value: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(10080))}}, + }, + Usage: &SharesFloat128TresList{ + {Name: Ptr("cpu"), Value: Ptr(50000.5)}, + {Name: Ptr("mem"), Value: Ptr(20000.0)}, + }, + }, + EffectiveUsage: Ptr(0.33), + UsageNormalized: &Float64NoVal{Set: Ptr(true), Number: Ptr(0.33)}, + Usage: Ptr(int64(30000)), + Fairshare: &AssocSharesObjFairshare{ + Factor: Ptr(0.85), + Level: Ptr(0.7), + }, + Type: []string{"USER"}, + } + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var decoded AssocSharesObjWrap + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if *decoded.ID != *orig.ID { + t.Errorf("ID: got %d, want %d", *decoded.ID, *orig.ID) + } + if *decoded.Name != *orig.Name { + t.Errorf("Name: got %s, want %s", *decoded.Name, *orig.Name) + } + if *decoded.Parent != *orig.Parent { + t.Errorf("Parent: got %s, want %s", *decoded.Parent, *orig.Parent) + } + if *decoded.Partition != *orig.Partition { + t.Errorf("Partition: got %s, want %s", *decoded.Partition, *orig.Partition) + } + if *decoded.Shares.Number != *orig.Shares.Number { + t.Errorf("Shares.Number: got %d, want %d", *decoded.Shares.Number, *orig.Shares.Number) + } + if *decoded.Fairshare.Factor != *orig.Fairshare.Factor { + t.Errorf("Fairshare.Factor: got %f, want %f", *decoded.Fairshare.Factor, *orig.Fairshare.Factor) + } + if len(decoded.Type) != 1 || decoded.Type[0] != "USER" { + t.Errorf("Type: got %v", decoded.Type) + } + if decoded.Tres == nil { + t.Fatal("Tres is nil") + } + if len(*decoded.Tres.Usage) != 2 { + t.Errorf("Tres.Usage length: got %d, want 2", len(*decoded.Tres.Usage)) + } + if len(*decoded.Tres.RunSeconds) != 1 { + t.Errorf("Tres.RunSeconds length: got %d, want 1", len(*decoded.Tres.RunSeconds)) + } +} + +func TestOpenapiSharesRespRoundTrip(t *testing.T) { + orig := OpenapiSharesResp{ + Shares: &SharesRespMsg{ + Shares: &AssocSharesObjList{ + { + ID: Ptr(int32(1)), + Name: Ptr("root"), + Cluster: Ptr("cluster1"), + Shares: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(100))}, + Fairshare: &AssocSharesObjFairshare{Factor: Ptr(1.0)}, + Type: []string{"ASSOCIATION"}, + }, + }, + TotalShares: Ptr(int64(100)), + }, + Meta: &OpenapiMeta{ + Slurm: &MetaSlurm{ + Version: &MetaSlurmVersion{Major: Ptr("24"), Minor: Ptr("05"), Micro: Ptr("5")}, + Release: Ptr("24.05.5"), + }, + }, + Errors: OpenapiErrors{}, + Warnings: OpenapiWarnings{}, + } + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var decoded OpenapiSharesResp + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if decoded.Shares == nil { + t.Fatal("Shares is nil") + } + if *decoded.Shares.TotalShares != *orig.Shares.TotalShares { + t.Errorf("Shares.TotalShares: got %d, want %d", *decoded.Shares.TotalShares, *orig.Shares.TotalShares) + } + if decoded.Meta == nil || decoded.Meta.Slurm == nil { + t.Fatal("Meta.Slurm is nil") + } + if *decoded.Meta.Slurm.Release != *orig.Meta.Slurm.Release { + t.Errorf("Meta.Slurm.Release: got %s, want %s", *decoded.Meta.Slurm.Release, *orig.Meta.Slurm.Release) + } + if len(*decoded.Shares.Shares) != 1 { + t.Fatalf("Shares.Shares length: got %d, want 1", len(*decoded.Shares.Shares)) + } + if *(*decoded.Shares.Shares)[0].Name != "root" { + t.Errorf("Shares.Shares[0].Name: got %s", *(*decoded.Shares.Shares)[0].Name) + } +} + +func TestSharesEmptyRoundTrip(t *testing.T) { + orig := SharesFloat128Tres{} + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if string(data) != "{}" { + t.Errorf("empty SharesFloat128Tres: got %s, want {}", data) + } + + origUint := SharesUint64Tres{} + data, err = json.Marshal(origUint) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if string(data) != "{}" { + t.Errorf("empty SharesUint64Tres: got %s, want {}", data) + } + + origWrap := AssocSharesObjWrap{} + data, err = json.Marshal(origWrap) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if string(data) != "{}" { + t.Errorf("empty AssocSharesObjWrap: got %s, want {}", data) + } +}