feat: 添加 Shares 公平共享类型和 SharesService
包含 SharesRespMsg、AssocSharesObjWrap、TRES 明细等类型。SharesService 提供 GetShares 方法,支持按 accounts 和 users 查询。 Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
44
internal/slurm/slurm_shares.go
Normal file
44
internal/slurm/slurm_shares.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package slurm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// GetSharesOptions specifies optional parameters for GetShares.
|
||||
type GetSharesOptions struct {
|
||||
Accounts *string `url:"accounts,omitempty"`
|
||||
Users *string `url:"users,omitempty"`
|
||||
}
|
||||
|
||||
// GetShares retrieves fairshare shares information.
|
||||
func (s *SharesService) GetShares(ctx context.Context, opts *GetSharesOptions) (*OpenapiSharesResp, *Response, error) {
|
||||
path := "slurm/v0.0.40/shares"
|
||||
req, err := s.client.NewRequest("GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if opts != nil {
|
||||
u, parseErr := url.Parse(req.URL.String())
|
||||
if parseErr != nil {
|
||||
return nil, nil, parseErr
|
||||
}
|
||||
q := u.Query()
|
||||
if opts.Accounts != nil {
|
||||
q.Set("accounts", *opts.Accounts)
|
||||
}
|
||||
if opts.Users != nil {
|
||||
q.Set("users", *opts.Users)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
req.URL = u
|
||||
}
|
||||
|
||||
var result OpenapiSharesResp
|
||||
resp, err := s.client.Do(ctx, req, &result)
|
||||
if err != nil {
|
||||
return nil, resp, err
|
||||
}
|
||||
return &result, resp, nil
|
||||
}
|
||||
84
internal/slurm/slurm_shares_test.go
Normal file
84
internal/slurm/slurm_shares_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package slurm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSharesService_GetShares(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/slurm/v0.0.40/shares", func(w http.ResponseWriter, r *http.Request) {
|
||||
testMethod(t, r, "GET")
|
||||
fmt.Fprint(w, `{"shares": {"shares": [], "total_shares": 100}}`)
|
||||
})
|
||||
server := httptest.NewServer(mux)
|
||||
defer server.Close()
|
||||
|
||||
client, _ := NewClient(server.URL, nil)
|
||||
resp, _, err := client.Shares.GetShares(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil response")
|
||||
}
|
||||
if resp.Shares == nil {
|
||||
t.Fatal("expected non-nil shares")
|
||||
}
|
||||
if resp.Shares.TotalShares == nil || *resp.Shares.TotalShares != 100 {
|
||||
t.Errorf("expected total_shares=100, got %v", resp.Shares.TotalShares)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharesService_GetShares_WithOptions(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/slurm/v0.0.40/shares", func(w http.ResponseWriter, r *http.Request) {
|
||||
testMethod(t, r, "GET")
|
||||
q := r.URL.Query()
|
||||
if q.Get("accounts") != "acc1" {
|
||||
t.Errorf("expected accounts=acc1, got %s", q.Get("accounts"))
|
||||
}
|
||||
if q.Get("users") != "user1" {
|
||||
t.Errorf("expected users=user1, got %s", q.Get("users"))
|
||||
}
|
||||
fmt.Fprint(w, `{"shares": {"shares": [], "total_shares": 50}}`)
|
||||
})
|
||||
server := httptest.NewServer(mux)
|
||||
defer server.Close()
|
||||
|
||||
client, _ := NewClient(server.URL, nil)
|
||||
opts := &GetSharesOptions{
|
||||
Accounts: Ptr("acc1"),
|
||||
Users: Ptr("user1"),
|
||||
}
|
||||
resp, _, err := client.Shares.GetShares(context.Background(), opts)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharesService_GetShares_Error(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/slurm/v0.0.40/shares", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
fmt.Fprint(w, `{"errors": [{"error": "internal error"}]}`)
|
||||
})
|
||||
server := httptest.NewServer(mux)
|
||||
defer server.Close()
|
||||
|
||||
client, _ := NewClient(server.URL, nil)
|
||||
_, _, err := client.Shares.GetShares(context.Background(), nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 500 response")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") {
|
||||
t.Errorf("expected error to contain 500, got %v", err)
|
||||
}
|
||||
}
|
||||
70
internal/slurm/types_shares.go
Normal file
70
internal/slurm/types_shares.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package slurm
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shares — types for the Slurm fairshare/shares API (v0.0.40)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SharesFloat128Tres represents a float128 TRES entry (v0.0.40_shares_float128_tres).
|
||||
type SharesFloat128Tres struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Value *float64 `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
// SharesFloat128TresList is an array of SharesFloat128Tres (v0.0.40_shares_float128_tres_list).
|
||||
type SharesFloat128TresList []SharesFloat128Tres
|
||||
|
||||
// SharesUint64Tres represents a uint64 TRES entry (v0.0.40_shares_uint64_tres).
|
||||
type SharesUint64Tres struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Value *Uint64NoVal `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
// SharesUint64TresList is an array of SharesUint64Tres (v0.0.40_shares_uint64_tres_list).
|
||||
type SharesUint64TresList []SharesUint64Tres
|
||||
|
||||
// AssocSharesObjTres holds TRES usage/limit breakdowns within an association share.
|
||||
type AssocSharesObjTres struct {
|
||||
RunSeconds *SharesUint64TresList `json:"run_seconds,omitempty"`
|
||||
GroupMinutes *SharesUint64TresList `json:"group_minutes,omitempty"`
|
||||
Usage *SharesFloat128TresList `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// AssocSharesObjFairshare holds fairshare factor and level.
|
||||
type AssocSharesObjFairshare struct {
|
||||
Factor *float64 `json:"factor,omitempty"`
|
||||
Level *float64 `json:"level,omitempty"`
|
||||
}
|
||||
|
||||
// AssocSharesObjWrap represents an association shares object (v0.0.40_assoc_shares_obj_wrap).
|
||||
type AssocSharesObjWrap struct {
|
||||
ID *int32 `json:"id,omitempty"`
|
||||
Cluster *string `json:"cluster,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Parent *string `json:"parent,omitempty"`
|
||||
Partition *string `json:"partition,omitempty"`
|
||||
SharesNormalized *Float64NoVal `json:"shares_normalized,omitempty"`
|
||||
Shares *Uint32NoVal `json:"shares,omitempty"`
|
||||
Tres *AssocSharesObjTres `json:"tres,omitempty"`
|
||||
EffectiveUsage *float64 `json:"effective_usage,omitempty"`
|
||||
UsageNormalized *Float64NoVal `json:"usage_normalized,omitempty"`
|
||||
Usage *int64 `json:"usage,omitempty"`
|
||||
Fairshare *AssocSharesObjFairshare `json:"fairshare,omitempty"`
|
||||
Type []string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
// AssocSharesObjList is an array of AssocSharesObjWrap (v0.0.40_assoc_shares_obj_list).
|
||||
type AssocSharesObjList []AssocSharesObjWrap
|
||||
|
||||
// SharesRespMsg is the shares response message (v0.0.40_shares_resp_msg).
|
||||
type SharesRespMsg struct {
|
||||
Shares *AssocSharesObjList `json:"shares,omitempty"`
|
||||
TotalShares *int64 `json:"total_shares,omitempty"`
|
||||
}
|
||||
|
||||
// OpenapiSharesResp is the top-level shares API response (v0.0.40_openapi_shares_resp).
|
||||
type OpenapiSharesResp struct {
|
||||
Shares *SharesRespMsg `json:"shares,omitempty"`
|
||||
Meta *OpenapiMeta `json:"meta,omitempty"`
|
||||
Errors OpenapiErrors `json:"errors,omitempty"`
|
||||
Warnings OpenapiWarnings `json:"warnings,omitempty"`
|
||||
}
|
||||
271
internal/slurm/types_shares_test.go
Normal file
271
internal/slurm/types_shares_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package slurm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSharesFloat128TresRoundTrip(t *testing.T) {
|
||||
orig := SharesFloat128Tres{
|
||||
Name: Ptr("cpu"),
|
||||
Value: Ptr(1234.56),
|
||||
}
|
||||
data, err := json.Marshal(orig)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var decoded SharesFloat128Tres
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if *decoded.Name != *orig.Name {
|
||||
t.Errorf("Name: got %s, want %s", *decoded.Name, *orig.Name)
|
||||
}
|
||||
if *decoded.Value != *orig.Value {
|
||||
t.Errorf("Value: got %f, want %f", *decoded.Value, *orig.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharesUint64TresRoundTrip(t *testing.T) {
|
||||
orig := SharesUint64Tres{
|
||||
Name: Ptr("mem"),
|
||||
Value: &Uint64NoVal{
|
||||
Set: Ptr(true),
|
||||
Number: Ptr(int64(4096)),
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(orig)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var decoded SharesUint64Tres
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if *decoded.Name != *orig.Name {
|
||||
t.Errorf("Name: got %s, want %s", *decoded.Name, *orig.Name)
|
||||
}
|
||||
if *decoded.Value.Number != *orig.Value.Number {
|
||||
t.Errorf("Value.Number: got %d, want %d", *decoded.Value.Number, *orig.Value.Number)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharesRespMsgRoundTrip(t *testing.T) {
|
||||
orig := SharesRespMsg{
|
||||
Shares: &AssocSharesObjList{
|
||||
{
|
||||
ID: Ptr(int32(1)),
|
||||
Cluster: Ptr("cluster1"),
|
||||
Name: Ptr("root"),
|
||||
Parent: Ptr(""),
|
||||
Shares: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(100))},
|
||||
Type: []string{"ASSOCIATION"},
|
||||
Fairshare: &AssocSharesObjFairshare{
|
||||
Factor: Ptr(0.5),
|
||||
Level: Ptr(0.5),
|
||||
},
|
||||
Tres: &AssocSharesObjTres{
|
||||
Usage: &SharesFloat128TresList{
|
||||
{Name: Ptr("cpu"), Value: Ptr(100000.0)},
|
||||
},
|
||||
RunSeconds: &SharesUint64TresList{
|
||||
{Name: Ptr("cpu"), Value: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(3600))}},
|
||||
},
|
||||
},
|
||||
SharesNormalized: &Float64NoVal{Set: Ptr(true), Number: Ptr(1.0)},
|
||||
EffectiveUsage: Ptr(0.75),
|
||||
UsageNormalized: &Float64NoVal{Set: Ptr(true), Number: Ptr(0.75)},
|
||||
Usage: Ptr(int64(50000)),
|
||||
},
|
||||
},
|
||||
TotalShares: Ptr(int64(100)),
|
||||
}
|
||||
data, err := json.Marshal(orig)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var decoded SharesRespMsg
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if *decoded.TotalShares != *orig.TotalShares {
|
||||
t.Errorf("TotalShares: got %d, want %d", *decoded.TotalShares, *orig.TotalShares)
|
||||
}
|
||||
if len(*decoded.Shares) != len(*orig.Shares) {
|
||||
t.Fatalf("Shares length: got %d, want %d", len(*decoded.Shares), len(*orig.Shares))
|
||||
}
|
||||
got := (*decoded.Shares)[0]
|
||||
want := (*orig.Shares)[0]
|
||||
if *got.Name != *want.Name {
|
||||
t.Errorf("Shares[0].Name: got %s, want %s", *got.Name, *want.Name)
|
||||
}
|
||||
if *got.Cluster != *want.Cluster {
|
||||
t.Errorf("Shares[0].Cluster: got %s, want %s", *got.Cluster, *want.Cluster)
|
||||
}
|
||||
if *got.Shares.Number != *want.Shares.Number {
|
||||
t.Errorf("Shares[0].Shares.Number: got %d, want %d", *got.Shares.Number, *want.Shares.Number)
|
||||
}
|
||||
if len(got.Type) != 1 || got.Type[0] != "ASSOCIATION" {
|
||||
t.Errorf("Shares[0].Type: got %v", got.Type)
|
||||
}
|
||||
if got.Fairshare == nil || *got.Fairshare.Factor != 0.5 {
|
||||
t.Errorf("Shares[0].Fairshare.Factor: got %v", got.Fairshare)
|
||||
}
|
||||
if got.Tres == nil || got.Tres.Usage == nil || len(*got.Tres.Usage) != 1 {
|
||||
t.Fatal("Shares[0].Tres.Usage missing or wrong length")
|
||||
}
|
||||
if *(*got.Tres.Usage)[0].Name != "cpu" {
|
||||
t.Errorf("Shares[0].Tres.Usage[0].Name: got %s", *(*got.Tres.Usage)[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssocSharesObjWrapRoundTrip(t *testing.T) {
|
||||
orig := AssocSharesObjWrap{
|
||||
ID: Ptr(int32(42)),
|
||||
Cluster: Ptr("test-cluster"),
|
||||
Name: Ptr("alice"),
|
||||
Parent: Ptr("team-a"),
|
||||
Partition: Ptr("gpu"),
|
||||
SharesNormalized: &Float64NoVal{Set: Ptr(true), Number: Ptr(0.25)},
|
||||
Shares: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(50))},
|
||||
Tres: &AssocSharesObjTres{
|
||||
RunSeconds: &SharesUint64TresList{
|
||||
{Name: Ptr("cpu"), Value: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(7200))}},
|
||||
},
|
||||
GroupMinutes: &SharesUint64TresList{
|
||||
{Name: Ptr("cpu"), Value: &Uint64NoVal{Set: Ptr(true), Number: Ptr(int64(10080))}},
|
||||
},
|
||||
Usage: &SharesFloat128TresList{
|
||||
{Name: Ptr("cpu"), Value: Ptr(50000.5)},
|
||||
{Name: Ptr("mem"), Value: Ptr(20000.0)},
|
||||
},
|
||||
},
|
||||
EffectiveUsage: Ptr(0.33),
|
||||
UsageNormalized: &Float64NoVal{Set: Ptr(true), Number: Ptr(0.33)},
|
||||
Usage: Ptr(int64(30000)),
|
||||
Fairshare: &AssocSharesObjFairshare{
|
||||
Factor: Ptr(0.85),
|
||||
Level: Ptr(0.7),
|
||||
},
|
||||
Type: []string{"USER"},
|
||||
}
|
||||
data, err := json.Marshal(orig)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var decoded AssocSharesObjWrap
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if *decoded.ID != *orig.ID {
|
||||
t.Errorf("ID: got %d, want %d", *decoded.ID, *orig.ID)
|
||||
}
|
||||
if *decoded.Name != *orig.Name {
|
||||
t.Errorf("Name: got %s, want %s", *decoded.Name, *orig.Name)
|
||||
}
|
||||
if *decoded.Parent != *orig.Parent {
|
||||
t.Errorf("Parent: got %s, want %s", *decoded.Parent, *orig.Parent)
|
||||
}
|
||||
if *decoded.Partition != *orig.Partition {
|
||||
t.Errorf("Partition: got %s, want %s", *decoded.Partition, *orig.Partition)
|
||||
}
|
||||
if *decoded.Shares.Number != *orig.Shares.Number {
|
||||
t.Errorf("Shares.Number: got %d, want %d", *decoded.Shares.Number, *orig.Shares.Number)
|
||||
}
|
||||
if *decoded.Fairshare.Factor != *orig.Fairshare.Factor {
|
||||
t.Errorf("Fairshare.Factor: got %f, want %f", *decoded.Fairshare.Factor, *orig.Fairshare.Factor)
|
||||
}
|
||||
if len(decoded.Type) != 1 || decoded.Type[0] != "USER" {
|
||||
t.Errorf("Type: got %v", decoded.Type)
|
||||
}
|
||||
if decoded.Tres == nil {
|
||||
t.Fatal("Tres is nil")
|
||||
}
|
||||
if len(*decoded.Tres.Usage) != 2 {
|
||||
t.Errorf("Tres.Usage length: got %d, want 2", len(*decoded.Tres.Usage))
|
||||
}
|
||||
if len(*decoded.Tres.RunSeconds) != 1 {
|
||||
t.Errorf("Tres.RunSeconds length: got %d, want 1", len(*decoded.Tres.RunSeconds))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenapiSharesRespRoundTrip(t *testing.T) {
|
||||
orig := OpenapiSharesResp{
|
||||
Shares: &SharesRespMsg{
|
||||
Shares: &AssocSharesObjList{
|
||||
{
|
||||
ID: Ptr(int32(1)),
|
||||
Name: Ptr("root"),
|
||||
Cluster: Ptr("cluster1"),
|
||||
Shares: &Uint32NoVal{Set: Ptr(true), Number: Ptr(int64(100))},
|
||||
Fairshare: &AssocSharesObjFairshare{Factor: Ptr(1.0)},
|
||||
Type: []string{"ASSOCIATION"},
|
||||
},
|
||||
},
|
||||
TotalShares: Ptr(int64(100)),
|
||||
},
|
||||
Meta: &OpenapiMeta{
|
||||
Slurm: &MetaSlurm{
|
||||
Version: &MetaSlurmVersion{Major: Ptr("24"), Minor: Ptr("05"), Micro: Ptr("5")},
|
||||
Release: Ptr("24.05.5"),
|
||||
},
|
||||
},
|
||||
Errors: OpenapiErrors{},
|
||||
Warnings: OpenapiWarnings{},
|
||||
}
|
||||
data, err := json.Marshal(orig)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var decoded OpenapiSharesResp
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if decoded.Shares == nil {
|
||||
t.Fatal("Shares is nil")
|
||||
}
|
||||
if *decoded.Shares.TotalShares != *orig.Shares.TotalShares {
|
||||
t.Errorf("Shares.TotalShares: got %d, want %d", *decoded.Shares.TotalShares, *orig.Shares.TotalShares)
|
||||
}
|
||||
if decoded.Meta == nil || decoded.Meta.Slurm == nil {
|
||||
t.Fatal("Meta.Slurm is nil")
|
||||
}
|
||||
if *decoded.Meta.Slurm.Release != *orig.Meta.Slurm.Release {
|
||||
t.Errorf("Meta.Slurm.Release: got %s, want %s", *decoded.Meta.Slurm.Release, *orig.Meta.Slurm.Release)
|
||||
}
|
||||
if len(*decoded.Shares.Shares) != 1 {
|
||||
t.Fatalf("Shares.Shares length: got %d, want 1", len(*decoded.Shares.Shares))
|
||||
}
|
||||
if *(*decoded.Shares.Shares)[0].Name != "root" {
|
||||
t.Errorf("Shares.Shares[0].Name: got %s", *(*decoded.Shares.Shares)[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharesEmptyRoundTrip(t *testing.T) {
|
||||
orig := SharesFloat128Tres{}
|
||||
data, err := json.Marshal(orig)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
if string(data) != "{}" {
|
||||
t.Errorf("empty SharesFloat128Tres: got %s, want {}", data)
|
||||
}
|
||||
|
||||
origUint := SharesUint64Tres{}
|
||||
data, err = json.Marshal(origUint)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
if string(data) != "{}" {
|
||||
t.Errorf("empty SharesUint64Tres: got %s, want {}", data)
|
||||
}
|
||||
|
||||
origWrap := AssocSharesObjWrap{}
|
||||
data, err = json.Marshal(origWrap)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
if string(data) != "{}" {
|
||||
t.Errorf("empty AssocSharesObjWrap: got %s, want {}", data)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user