diff options
| author | Bobby <[email protected]> | 2025-04-13 18:10:13 +0530 |
|---|---|---|
| committer | Bobby <[email protected]> | 2025-04-13 18:10:13 +0530 |
| commit | 12fb3704db217cc408b662ec73cdd41e028c0e08 (patch) | |
| tree | 4cd3a34898d362773d77a31e03695cb2480286c7 /utils | |
| parent | 1b271061415b33a8f18d1d3d960bc750b9557b69 (diff) | |
| download | ai-12fb3704db217cc408b662ec73cdd41e028c0e08.tar.xz ai-12fb3704db217cc408b662ec73cdd41e028c0e08.zip | |
implement play command; music playback working
Diffstat (limited to 'utils')
| -rw-r--r-- | utils/music/search.go | 418 | ||||
| -rw-r--r-- | utils/music/voice.go | 387 |
2 files changed, 805 insertions, 0 deletions
diff --git a/utils/music/search.go b/utils/music/search.go new file mode 100644 index 0000000..3cef107 --- /dev/null +++ b/utils/music/search.go @@ -0,0 +1,418 @@ +package music + +import ( + "ai/config" + "ai/types" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strings" + "sync" +) + +var ( + youtubeRegex = regexp.MustCompile(`^(https?://)?(www\.)?(youtube\.com|youtu\.?be)/.+`) + spotifyRegex = regexp.MustCompile(`^(https?://)?(open\.)?spotify\.com/.+`) +) + +func IsYouTubeURL(input string) bool { + return youtubeRegex.MatchString(input) +} + +func IsSpotifyURL(input string) bool { + return spotifyRegex.MatchString(input) +} + +func Search(query string, limit int) ([]types.MusicSearchResult, error) { + var wg sync.WaitGroup + wg.Add(2) + + var youtubeResults []types.MusicSearchResult + var spotifyResults []types.MusicSearchResult + var youtubeErr, spotifyErr error + + go func() { + defer wg.Done() + youtubeResults, youtubeErr = SearchYouTube(query, limit/2) + }() + + go func() { + defer wg.Done() + spotifyResults, spotifyErr = SearchSpotify(query, limit/2) + }() + + wg.Wait() + + if youtubeErr != nil && spotifyErr != nil { + return nil, fmt.Errorf("both search errors: youtube: %w, spotify: %w", youtubeErr, spotifyErr) + } + + results := []types.MusicSearchResult{} + + maxLength := max(len(spotifyResults), len(youtubeResults)) + + for i := range maxLength { + if i < len(youtubeResults) { + results = append(results, youtubeResults[i]) + } + if i < len(spotifyResults) { + results = append(results, spotifyResults[i]) + } + } + + if len(results) > limit { + results = results[:limit] + } + + return results, nil +} + +func SearchSpotify(query string, limit int) ([]types.MusicSearchResult, error) { + token, err := getSpotifyToken() + if err != nil { + return nil, fmt.Errorf("spotify token error: %w", err) + } + + searchURL := fmt.Sprintf("https://api.spotify.com/v1/search?q=%s&type=track&limit=%d", url.QueryEscape(query), limit) + req, err := http.NewRequest("GET", searchURL, nil) + if err != nil { + return nil, fmt.Errorf("request creation error: %w", err) + } + + req.Header.Add("Authorization", "Bearer "+token) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("search request error: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("response read error: %w", err) + } + + var searchResponse types.SpotifySearchResponse + if err := json.Unmarshal(body, &searchResponse); err != nil { + return nil, fmt.Errorf("json unmarshal error: %w", err) + } + + results := []types.MusicSearchResult{} + + for _, item := range searchResponse.Tracks.Items { + artistName := "" + if len(item.Artists) > 0 { + artistName = item.Artists[0].Name + } + + thumbnailURL := "" + if len(item.Album.Images) > 0 { + thumbnailURL = item.Album.Images[0].URL + } + + // Format duration as mm:ss + durationSec := item.DurationMs / 1000 + duration := fmt.Sprintf("%02d:%02d", durationSec/60, durationSec%60) + + results = append(results, types.MusicSearchResult{ + Title: item.Name, + Artist: artistName, + URL: item.ExternalUrls.Spotify, + ID: item.ID, + Duration: duration, + Thumbnail: thumbnailURL, + SourceType: types.Spotify, + }) + } + + return results, nil +} + +func SearchYouTube(query string, limit int) ([]types.MusicSearchResult, error) { + apiKey := config.Config.YoutubeAPIKey + searchURL := fmt.Sprintf( + "https://www.googleapis.com/youtube/v3/search?part=snippet&q=%s&key=%s&maxResults=%d&type=video", + url.QueryEscape(query), apiKey, limit, + ) + + resp, err := http.Get(searchURL) + if err != nil { + return nil, fmt.Errorf("search request error: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("response read error: %w", err) + } + + var searchResponse types.YouTubeSearchResponse + if err := json.Unmarshal(body, &searchResponse); err != nil { + return nil, fmt.Errorf("json unmarshal error: %w", err) + } + + results := []types.MusicSearchResult{} + + for _, item := range searchResponse.Items { + videoURL := fmt.Sprintf("https://www.youtube.com/watch?v=%s", item.ID.VideoID) + + results = append(results, types.MusicSearchResult{ + Title: item.Snippet.Title, + Artist: item.Snippet.ChannelTitle, + URL: videoURL, + ID: item.ID.VideoID, + Duration: "00:00", // YouTube API requires a separate call to get duration + Thumbnail: item.Snippet.Thumbnails.High.URL, + SourceType: types.YouTube, + }) + } + + return results, nil +} + +func GetTrackInfo(id string, sourceType types.SourceType) (types.MusicSearchResult, error) { + if sourceType == types.YouTube { + return GetYouTubeInfoByID(id) + } else if sourceType == types.Spotify { + return GetSpotifyInfoByID(id) + } + + return types.MusicSearchResult{}, fmt.Errorf("unsupported source type: %s", sourceType) +} + +func GetYouTubeInfoByID(videoID string) (types.MusicSearchResult, error) { + apiKey := config.Config.YoutubeAPIKey + apiURL := fmt.Sprintf( + "https://www.googleapis.com/youtube/v3/videos?part=contentDetails,snippet&id=%s&key=%s", + videoID, apiKey, + ) + + resp, err := http.Get(apiURL) + if err != nil { + return types.MusicSearchResult{}, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return types.MusicSearchResult{}, err + } + + var response struct { + Items []struct { + Snippet struct { + Title string `json:"title"` + ChannelTitle string `json:"channelTitle"` + Thumbnails struct { + High struct { + URL string `json:"url"` + } `json:"high"` + } `json:"thumbnails"` + } `json:"snippet"` + ContentDetails struct { + Duration string `json:"duration"` + } `json:"contentDetails"` + } `json:"items"` + } + + err = json.Unmarshal(body, &response) + if err != nil { + return types.MusicSearchResult{}, err + } + + if len(response.Items) == 0 { + return types.MusicSearchResult{}, fmt.Errorf("video not found") + } + + item := response.Items[0] + return types.MusicSearchResult{ + Title: item.Snippet.Title, + Artist: item.Snippet.ChannelTitle, + URL: fmt.Sprintf("https://www.youtube.com/watch?v=%s", videoID), + ID: videoID, + Duration: item.ContentDetails.Duration, + Thumbnail: item.Snippet.Thumbnails.High.URL, + SourceType: types.YouTube, + }, nil +} + +func GetSpotifyInfoByID(trackID string) (types.MusicSearchResult, error) { + token, err := getSpotifyToken() + if err != nil { + return types.MusicSearchResult{}, err + } + + apiURL := "https://api.spotify.com/v1/tracks/" + trackID + + req, err := http.NewRequest("GET", apiURL, nil) + if err != nil { + return types.MusicSearchResult{}, err + } + + req.Header.Add("Authorization", "Bearer "+token) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return types.MusicSearchResult{}, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return types.MusicSearchResult{}, err + } + + var trackResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Artists []struct { + Name string `json:"name"` + } `json:"artists"` + Album struct { + Images []struct { + URL string `json:"url"` + } `json:"images"` + } `json:"album"` + DurationMs int `json:"duration_ms"` + ExternalUrls struct { + Spotify string `json:"spotify"` + } `json:"external_urls"` + } + + err = json.Unmarshal(body, &trackResponse) + if err != nil { + return types.MusicSearchResult{}, err + } + + artistName := "" + if len(trackResponse.Artists) > 0 { + artistName = trackResponse.Artists[0].Name + } + + thumbnailURL := "" + if len(trackResponse.Album.Images) > 0 { + thumbnailURL = trackResponse.Album.Images[0].URL + } + + duration := fmt.Sprintf("%02d:%02d", trackResponse.DurationMs/60000, (trackResponse.DurationMs/1000)%60) + + return types.MusicSearchResult{ + Title: trackResponse.Name, + Artist: artistName, + URL: trackResponse.ExternalUrls.Spotify, + ID: trackResponse.ID, + Duration: duration, + Thumbnail: thumbnailURL, + SourceType: types.Spotify, + }, nil +} + +func GetYouTubeForSpotify(title, artist string) (types.MusicSearchResult, error) { + query := fmt.Sprintf("%s %s", title, artist) + + results, err := SearchYouTube(query, 1) + if err != nil { + return types.MusicSearchResult{}, err + } + + if len(results) == 0 { + return types.MusicSearchResult{}, fmt.Errorf("no YouTube results found") + } + + return results[0], nil +} + +func GetYouTubeInfo(ytURL string) (types.MusicSearchResult, error) { + var videoID string + + if strings.Contains(ytURL, "youtu.be") { + parts := strings.Split(ytURL, "/") + videoID = parts[len(parts)-1] + } else if strings.Contains(ytURL, "youtube.com") { + parsedURL, err := url.Parse(ytURL) + if err != nil { + return types.MusicSearchResult{}, err + } + + query := parsedURL.Query() + videoID = query.Get("v") + } + + if videoID == "" { + return types.MusicSearchResult{}, fmt.Errorf("could not extract video ID from URL") + } + + return GetYouTubeInfoByID(videoID) +} + +func GetSpotifyInfo(spotifyURL string) (types.MusicSearchResult, error) { + var trackID string + + if strings.Contains(spotifyURL, "track") { + parts := strings.Split(spotifyURL, "/") + trackID = parts[len(parts)-1] + + // Remove any query parameters + if strings.Contains(trackID, "?") { + trackID = strings.Split(trackID, "?")[0] + } + } else { + return types.MusicSearchResult{}, fmt.Errorf("URL must be a Spotify track URL") + } + + if trackID == "" { + return types.MusicSearchResult{}, fmt.Errorf("could not extract track ID from URL") + } + + return GetSpotifyInfoByID(trackID) +} + +func getSpotifyToken() (string, error) { + clientID := config.Config.SpotifyClientId + clientSecret := config.Config.SpotifyClientSecret + + tokenURL := "https://accounts.spotify.com/api/token" + + data := url.Values{} + data.Set("grant_type", "client_credentials") + + req, err := http.NewRequest("POST", tokenURL, strings.NewReader(data.Encode())) + if err != nil { + return "", err + } + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(clientID, clientSecret) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + var tokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + + if err := json.Unmarshal(body, &tokenResponse); err != nil { + return "", err + } + if tokenResponse.TokenType != "Bearer" { + return "", fmt.Errorf("unexpected token type: %s", tokenResponse.TokenType) + } + + return tokenResponse.AccessToken, nil +} diff --git a/utils/music/voice.go b/utils/music/voice.go new file mode 100644 index 0000000..961a8f5 --- /dev/null +++ b/utils/music/voice.go @@ -0,0 +1,387 @@ +package music + +import ( + "encoding/binary" + "fmt" + "io" + "os" + "os/exec" + "sync" + "time" + + "github.com/bwmarrin/discordgo" + "layeh.com/gopus" +) + +const ( + channels int = 2 // 1 for mono, 2 for stereo + frameRate int = 48000 // audio sampling rate + frameSize int = 960 // size of each audio frame + maxBytes int = (frameSize * 2) * 2 // max size of opus data +) + +// VoiceInstance represents a voice connection +type VoiceInstance struct { + GuildID string + ChannelID string + Connection *discordgo.VoiceConnection + Playing bool + StopChannel chan bool + OpusEncoder *gopus.Encoder + mu sync.Mutex + CurrentTrackID string +} + +var ( + VoiceConnection = make(map[string]*VoiceInstance) + VoiceMutex = &sync.Mutex{} +) + +// Stop stops the current playback +func (v *VoiceInstance) Stop() { + v.mu.Lock() + defer v.mu.Unlock() + + if v.Playing { + select { + case v.StopChannel <- true: + // Signal sent here + default: + // Channel already has a signal + } + close(v.StopChannel) + v.StopChannel = make(chan bool, 1) + v.Playing = false + } +} + +// JoinVoiceChannel makes the bot join a voice channel +func JoinVoiceChannel(s *discordgo.Session, guildID, channelID string) (*VoiceInstance, error) { + VoiceMutex.Lock() + defer VoiceMutex.Unlock() + + // Check if already in a voice channel in this guild + if voice, exists := VoiceConnection[guildID]; exists { + return voice, nil + } + + // not in a voice channel, create a new one + vc, err := s.ChannelVoiceJoin(guildID, channelID, false, true) + if err != nil { + return nil, fmt.Errorf("failed to join voice channel: %w", err) + } + + encoder, err := gopus.NewEncoder(frameRate, channels, gopus.Audio) + if err != nil { + vc.Disconnect() + return nil, fmt.Errorf("failed to create opus encoder: %w", err) + } + + voiceInstance := &VoiceInstance{ + GuildID: guildID, + ChannelID: channelID, + Connection: vc, + Playing: false, + StopChannel: make(chan bool, 1), + OpusEncoder: encoder, + } + + VoiceConnection[guildID] = voiceInstance + return voiceInstance, nil +} + +// LeaveVoiceChannel makes the bot leave a voice channel +func LeaveVoiceChannel(guildID string) error { + VoiceMutex.Lock() + defer VoiceMutex.Unlock() + + voice, exists := VoiceConnection[guildID] + if !exists { + return fmt.Errorf("not in a voice channel") + } + + // Stop current playback + voice.Stop() + + // Disconnect + err := voice.Connection.Disconnect() + if err != nil { + return fmt.Errorf("failed to disconnect from voice channel: %w", err) + } + + // remove from map + delete(VoiceConnection, guildID) + return nil +} + +// IsUserInSameVC checks if the user is in the same voice channel as the bot +func IsUserInSameVC(s *discordgo.Session, guildID, userID string) (bool, string) { + // Voice state + guild, err := s.State.Guild(guildID) + if err != nil { + return false, "" + } + + var userChannelID string + for _, vs := range guild.VoiceStates { + if vs.UserID == userID { + userChannelID = vs.ChannelID + break + } + } + + if userChannelID == "" { + return false, "" // user not in a voice channel + } + + // Check if bot is in a voice channel + VoiceMutex.Lock() + defer VoiceMutex.Unlock() + + voice, exists := VoiceConnection[guildID] + if !exists { + return true, userChannelID // bot not in a voice channel, but no conflict + } + + return voice.ChannelID == userChannelID, userChannelID +} + +// PlayYouTube downloads and plays a YouTube video +func (v *VoiceInstance) PlayYouTube(videoURL, videoID string) error { + fmt.Printf("Starting to play: %s (ID: %s)\n", videoURL, videoID) + + // Create a new stop channel for this playback + var oldStopChan chan bool + + v.mu.Lock() + // If already playing, properly stop the previous playback + if v.Playing { + fmt.Println("Stopping current playback before starting new one...") + // Save the old channel to send the stop signal after we release the lock + oldStopChan = v.StopChannel + // Create a new channel for the new playback + v.StopChannel = make(chan bool, 1) + } else { + v.StopChannel = make(chan bool, 1) + } + + v.Playing = true + v.CurrentTrackID = videoID + stopChan := v.StopChannel + v.mu.Unlock() + + // Send stop signal to old channel if it exists + // Do this outside the lock to avoid deadlock + if oldStopChan != nil { + // Signal the old playback to stop + select { + case oldStopChan <- true: + fmt.Println("Stop signal sent to previous playback") + default: + fmt.Println("Could not send stop signal, channel might be full or closed") + } + // Wait a moment for the previous playback to clean up + time.Sleep(100 * time.Millisecond) + } + + // Ensure temp directory exists + err := os.MkdirAll("./temp", 0755) + if err != nil { + fmt.Printf("Error creating temp directory: %v\n", err) + return fmt.Errorf("failed to create temp directory: %w", err) + } + + // Create a unique filename + fileName := fmt.Sprintf("./temp/%s_%d.mp3", videoID, time.Now().Unix()) + fmt.Printf("Downloading to: %s\n", fileName) + + // Use yt-dlp to download audio + downloadCmd := exec.Command("yt-dlp", "-x", "--audio-format", "mp3", + "--audio-quality", "0", "--no-playlist", "--output", fileName, videoURL) + + // Set up pipes to capture output for debugging + downloadCmd.Stdout = os.Stdout + downloadCmd.Stderr = os.Stderr + + fmt.Println("Starting download...") + err = downloadCmd.Run() + if err != nil { + fmt.Printf("Download error: %v\n", err) + v.mu.Lock() + v.Playing = false + v.mu.Unlock() + return fmt.Errorf("failed to download audio: %w", err) + } + + fmt.Printf("Download complete, starting playback\n") + + // Check if file exists and get its size + fileInfo, err := os.Stat(fileName) + if err != nil { + fmt.Printf("File stat error: %v\n", err) + v.mu.Lock() + v.Playing = false + v.mu.Unlock() + return fmt.Errorf("file stat error: %w", err) + } + fmt.Printf("File size: %d bytes\n", fileInfo.Size()) + + // Ensure file gets deleted after playback + defer os.Remove(fileName) + + // Make sure we're not already in a speaking state + v.Connection.Speaking(false) + time.Sleep(50 * time.Millisecond) + + // Set speaking status + err = v.Connection.Speaking(true) + if err != nil { + fmt.Printf("Speaking error: %v\n", err) + v.mu.Lock() + v.Playing = false + v.mu.Unlock() + return fmt.Errorf("speaking error: %w", err) + } + defer v.Connection.Speaking(false) + + // Use ffmpeg for playback + ffmpeg := exec.Command("ffmpeg", "-i", fileName, "-f", "s16le", "-ar", "48000", "-ac", "2", "pipe:1") + ffmpegout, err := ffmpeg.StdoutPipe() + if err != nil { + fmt.Printf("FFmpeg pipe error: %v\n", err) + return fmt.Errorf("ffmpeg pipe error: %w", err) + } + + ffmpeg.Stderr = os.Stderr + err = ffmpeg.Start() + if err != nil { + fmt.Printf("FFmpeg start error: %v\n", err) + return fmt.Errorf("ffmpeg start error: %w", err) + } + + // Store ffmpeg process for proper cleanup + ffmpegProcess := ffmpeg.Process + defer func() { + ffmpegProcess.Kill() + ffmpeg.Wait() // Wait for the process to exit to avoid zombies + }() + + // Read and send loop + buf := make([]int16, frameSize*channels) + + playbackDone := make(chan error, 1) + go func() { + for { + // Read data from ffmpeg + err = binary.Read(ffmpegout, binary.LittleEndian, &buf) + if err == io.EOF || err == io.ErrUnexpectedEOF { + playbackDone <- nil + return + } + if err != nil { + playbackDone <- fmt.Errorf("error reading from ffmpeg: %w", err) + return + } + + // Encode with opus + opus, err := v.OpusEncoder.Encode(buf, frameSize, maxBytes) + if err != nil { + playbackDone <- fmt.Errorf("opus encoding error: %w", err) + return + } + + // Send to Discord + select { + case v.Connection.OpusSend <- opus: + // Sent successfully + case <-stopChan: + playbackDone <- nil + return + } + } + }() + + // Wait for playback to finish or stop signal + select { + case err := <-playbackDone: + if err != nil { + fmt.Printf("Playback error: %v\n", err) + } else { + fmt.Println("Playback completed normally") + } + case <-stopChan: + fmt.Println("Playback stopped by request") + } + + // Make sure to kill ffmpeg + ffmpegProcess.Kill() + + v.mu.Lock() + v.Playing = false + v.mu.Unlock() + + return nil +} + +// func (v *VoiceInstance) playAudioFile(filename string, stopChan chan bool) error { +// // Start ffmpeg to convert the file to PCM +// ffmpeg := exec.Command("ffmpeg", "-i", filename, "-f", "s16le", "-ar", "48000", "-ac", "2", "pipe:1") +// ffmpegout, err := ffmpeg.StdoutPipe() +// if err != nil { +// return fmt.Errorf("ffmpeg stdout error: %w", err) +// } + +// ffmpeg.Stderr = os.Stderr +// err = ffmpeg.Start() +// if err != nil { +// return fmt.Errorf("ffmpeg start error: %w", err) +// } + +// // Make sure to kill ffmpeg when we're done +// defer ffmpeg.Process.Kill() + +// // Set speaking status +// err = v.Connection.Speaking(true) +// if err != nil { +// return fmt.Errorf("speaking error: %w", err) +// } +// defer v.Connection.Speaking(false) + +// // Create a buffer for reading from ffmpeg +// buf := make([]int16, frameSize*channels) + +// // Read and send loop +// for { +// // Check if we've been asked to stop +// select { +// case <-stopChan: +// return nil +// default: +// // Continue playing +// } + +// // Read data from ffmpeg +// err = binary.Read(ffmpegout, binary.LittleEndian, &buf) +// if err == io.EOF || err == io.ErrUnexpectedEOF { +// // End of file +// return nil +// } +// if err != nil { +// return fmt.Errorf("error reading from ffmpeg: %w", err) +// } + +// // Encode with opus +// opus, err := v.OpusEncoder.Encode(buf, frameSize, maxBytes) +// if err != nil { +// return fmt.Errorf("opus encoding error: %w", err) +// } + +// // Send to Discord +// select { +// case v.Connection.OpusSend <- opus: +// // Sent successfully +// case <-stopChan: +// return nil +// } +// } +// } |
