291 lines
12 KiB
Python
291 lines
12 KiB
Python
from os import strerror
|
|
import aiohttp
|
|
import discord
|
|
import asyncio
|
|
import typing
|
|
import json
|
|
import socket
|
|
import time
|
|
import re
|
|
|
|
from discord.ext import commands
|
|
from typing import Optional, Union
|
|
from urllib.parse import quote
|
|
from . import spotify
|
|
from . import events
|
|
from . import exceptions
|
|
from . import objects
|
|
from . import __version__
|
|
from .utils import ExponentialBackoff, NodeStats
|
|
|
|
|
|
# Matches Spotify album/playlist/track URLs and captures their kind ("type")
# and Spotify ID ("id"). The host's dots are escaped so only the real
# open.spotify.com host matches; anything after the ID (e.g. "?si=...") is
# ignored because callers use .match(), which anchors only at the start.
SPOTIFY_URL_REGEX = re.compile(
    r"https?://open\.spotify\.com/(?P<type>album|playlist|track)/(?P<id>[a-zA-Z0-9]+)"
)
|
|
|
|
class Node:
    """The base class for a node.

    This node object represents a Lavalink node.

    If you want to enable Spotify searching, pass in a proper Spotify Client ID
    and Spotify Client Secret.
    """

    def __init__(
        self,
        pool,
        bot: Union[commands.Bot, discord.Client, commands.AutoShardedBot, discord.AutoShardedClient],
        host: str,
        port: int,
        password: str,
        identifier: str,
        spotify_client_id: Optional[str],
        spotify_client_secret: Optional[str],
    ):
        """Create a node; no connection is made until ``connect()`` is awaited.

        ``bot.user`` must already be set, because its ID is sent in the
        ``User-Id`` header. Spotify clients are only created when both
        ``spotify_client_id`` and ``spotify_client_secret`` are given.
        """
        self._bot = bot
        self._host = host
        self._port = port
        self._password = password
        self._identifier = identifier
        self._pool = pool

        self._websocket_uri = f"ws://{self._host}:{self._port}"
        self._rest_uri = f"http://{self._host}:{self._port}"

        self._session = aiohttp.ClientSession()
        self._websocket: Optional[aiohttp.ClientWebSocketResponse] = None
        self._task: Optional[asyncio.Task] = None

        self._connection_id = None
        self._metadata = None
        # False until connect() succeeds, so send() raises NodeNotAvailable
        # (instead of AttributeError) on a node that was never connected.
        self._available = False
        # Most recent "stats" payload, set by _handle_payload.
        self._stats = None

        self._headers = {
            "Authorization": self._password,
            "User-Id": str(self._bot.user.id),
            "Client-Name": f"Pomice/{__version__}",
        }

        self._players = {}

        self._spotify_client_id: Optional[str] = spotify_client_id
        self._spotify_client_secret: Optional[str] = spotify_client_secret
        # Always defined so get_tracks can test for a configured client.
        self._spotify_client: Optional[spotify.Client] = None
        self._spotify_http_client: Optional[spotify.HTTPClient] = None

        if self._spotify_client_id and self._spotify_client_secret:
            self._spotify_client = spotify.Client(self._spotify_client_id, self._spotify_client_secret)
            self._spotify_http_client = spotify.HTTPClient(self._spotify_client_id, self._spotify_client_secret)

        self._bot.add_listener(self._update_handler, "on_socket_event_type")

    def __repr__(self):
        return f"<Pomice.node ws_uri={self._websocket_uri} rest_uri={self._rest_uri} player_count={len(self._players)}>"

    @property
    def identifier(self) -> str:
        """Property which returns the identifier this node was created with."""
        return self._identifier

    @property
    def available(self) -> bool:
        """Property which returns whether this node may currently be sent payloads."""
        return self._available

    @available.setter
    def available(self, value: bool) -> None:
        self._available = value

    @property
    def is_connected(self) -> bool:
        """Property which returns whether this node is connected or not."""
        return self._websocket is not None and not self._websocket.closed

    @property
    async def latency(self):
        """Property which returns the latency of the node in milliseconds.

        Must be awaited (``await node.latency``): it sends a ``ping`` op and
        waits for a ``node_ping`` event.
        """
        start_time = time.time()
        await self.send(op="ping")
        # NOTE(review): nothing in this class dispatches "node_ping"; this
        # assumes another component dispatches it with a time.time() value.
        end_time = await self._bot.wait_for("node_ping")
        return (end_time - start_time) * 1000

    @property
    async def stats(self):
        """Property which returns the node stats at any given time.

        Must be awaited. Accessing this property is rarely needed, because
        Lavalink automatically sends updated node stats every minute.
        """
        await self.send(op="get-stats")
        # NOTE(review): like latency, relies on a "node_stats" event being
        # dispatched elsewhere.
        node_stats = await self._bot.wait_for("node_stats")
        return node_stats

    @property
    def players(self):
        """Property which returns a dict mapping guild IDs to player objects."""
        return self._players

    @property
    def bot(self):
        """Property which returns the discord.py Bot object linked to this node."""
        return self._bot

    @property
    def pool(self):
        """Property which returns the node pool this node is a part of."""
        return self._pool

    async def _update_handler(self, data: dict):
        """Forward Discord voice gateway events to the matching player.

        Registered in ``__init__`` as an ``on_socket_event_type`` listener.
        ``VOICE_SERVER_UPDATE`` events, and ``VOICE_STATE_UPDATE`` events for
        the bot's own user, are passed to the player of that guild (if any).
        Everything else is ignored.
        """
        # NOTE(review): assumes the listener receives the raw gateway payload
        # dict with "t"/"d" keys — confirm against the discord library in use.
        await self._bot.wait_until_ready()

        if not data:
            return

        event_type = data["t"]

        if event_type == "VOICE_SERVER_UPDATE":
            payload = data["d"]
            player = self._players.get(int(payload["guild_id"]))
            if player is not None:
                await player._voice_server_update(payload)

        elif event_type == "VOICE_STATE_UPDATE":
            payload = data["d"]
            if int(payload["user_id"]) != self._bot.user.id:
                return
            player = self._players.get(int(payload["guild_id"]))
            if player is not None:
                await player._voice_state_update(payload)

    async def _listen(self):
        """Read frames from this node's websocket until it closes.

        TEXT frames are decoded as JSON and handled in a separate task. On a
        CLOSE/CLOSED/ERROR frame the listener waits a backoff delay, schedules
        ``connect()`` if the node is still disconnected, and then exits. A
        successful ``connect()`` starts its own listener, so this one must not
        keep reading.
        """
        backoff = ExponentialBackoff(base=7)
        # Read only from the socket this listener was started for. After a
        # reconnect replaces self._websocket, two listeners never call
        # receive() on the same socket.
        websocket = self._websocket

        while True:
            msg = await websocket.receive()

            if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
                await asyncio.sleep(backoff.delay())
                if not self.is_connected:
                    self._bot.loop.create_task(self.connect())
                return

            if msg.type != aiohttp.WSMsgType.TEXT:
                # CLOSING, BINARY, PING/PONG carry no Lavalink payload; after
                # CLOSING, aiohttp reports CLOSED on the next receive().
                continue

            try:
                payload = msg.json()
            except ValueError:
                # Malformed frame: skip it rather than killing the listener.
                continue

            self._bot.loop.create_task(self._handle_payload(payload))

    async def _handle_payload(self, data: dict) -> None:
        """Route one decoded Lavalink payload by its ``op`` field.

        ``stats`` payloads replace the cached node stats. ``event`` and
        ``playerUpdate`` payloads go to the player for their ``guildId``.
        Payloads without an ``op``, without a ``guildId``, or for guilds with
        no player are ignored.
        """
        op = data.get("op")
        if not op:
            return

        if op == "stats":
            self._stats = NodeStats(data)
            return

        guild_id = data.get("guildId")
        if guild_id is None:
            return

        player = self._players.get(int(guild_id))
        if player is None:
            return

        if op == "event":
            await player._dispatch_event(data)
        elif op == "playerUpdate":
            await player._update_state(data)

    async def send(self, **data):
        """Send the keyword arguments to the node as one JSON websocket message.

        Raises:
            exceptions.NodeNotAvailable: if the node is not currently available.
        """
        if not self.available:
            raise exceptions.NodeNotAvailable(f"The node '{self.identifier}' is not currently available.")

        await self._websocket.send_str(json.dumps(data))

    def get_player(self, guild_id: int):
        """Takes a guild ID as a parameter. Returns a pomice Player object, or None."""
        return self._players.get(guild_id, None)

    async def connect(self):
        """Initiates a connection with a Lavalink node and adds it to the node pool.

        Returns:
            This node, once the websocket is open and its listener is running.

        Raises:
            exceptions.NodeConnectionFailure: if the websocket handshake is
                rejected, the URI is invalid, or the host cannot be resolved.
        """
        await self._bot.wait_until_ready()

        try:
            self._websocket = await self._session.ws_connect(self._websocket_uri, headers=self._headers, heartbeat=60)
        except aiohttp.WSServerHandshakeError as exc:
            raise exceptions.NodeConnectionFailure(f"The password for node '{self.identifier}' is invalid.") from exc
        except aiohttp.InvalidURL as exc:
            raise exceptions.NodeConnectionFailure(f"The URI for node '{self.identifier}' is invalid.") from exc
        except socket.gaierror as exc:
            raise exceptions.NodeConnectionFailure(f"The node '{self.identifier}' failed to connect.") from exc

        self._task = self._bot.loop.create_task(self._listen())
        self._pool._nodes[self._identifier] = self
        self.available = True
        return self

    async def disconnect(self):
        """Disconnects a connected Lavalink node and removes it from the node pool.

        This also destroys any players connected to the node.
        """
        for player in self.players.copy().values():
            await player.destroy()

        # Stop the listener first, so closing the socket below does not go
        # down the listener's reconnect path.
        if self._task is not None:
            self._task.cancel()

        if self._websocket is not None:
            await self._websocket.close()

        # Same mapping connect() registers into; pop() tolerates a second call.
        self._pool._nodes.pop(self._identifier, None)
        self.available = False
        # NOTE(review): self._session stays open here; closing it would make a
        # later connect() on this node fail — decide who owns the session.

    @staticmethod
    def _build_spotify_track(track, ctx):
        """Convert a Spotify track object into a Spotify-flagged pomice Track."""
        return objects.Track(
            track_id=track.id,
            ctx=ctx,
            spotify=True,
            info={
                "title": track.name or "Unknown",
                "author": ", ".join(artist.name for artist in track.artists) or "Unknown",
                "length": track.duration or 0,
                "identifier": track.id or "Unknown",
                "uri": track.url or "spotify",
                "isStream": False,
                "isSeekable": False,
                "position": 0,
                "thumbnail": track.images[0].url if track.images else None,
            },
        )

    async def get_tracks(self, query: str, ctx: Optional[commands.Context] = None):
        """Fetches tracks from the node's REST API to parse into Lavalink.

        If you passed in Spotify API credentials, you can also pass in a Spotify
        URL of a playlist, album or track, and it will be parsed accordingly.

        You can also pass in a discord.py Context object to get a Context object
        on any track you search.

        Returns:
            An ``objects.Playlist`` for playlists and Spotify albums, a list of
            ``objects.Track`` for single tracks and search results, or ``None``
            when Lavalink reports no matches or an unrecognised load type.

        Raises:
            exceptions.InvalidSpotifyClientAuthorization: a Spotify URL was
                given but this node has no Spotify client.
            exceptions.SpotifyPlaylistLoadFailed, exceptions.SpotifyAlbumLoadFailed,
            exceptions.SpotifyTrackLoadFailed: the Spotify lookup failed.
            exceptions.TrackLoadError: Lavalink reported a failure or returned
                no ``loadType``.
        """
        if spotify_url_check := SPOTIFY_URL_REGEX.match(query):
            return await self._get_spotify_tracks(query, spotify_url_check, ctx)
        return await self._load_lavalink_tracks(query, ctx)

    async def _get_spotify_tracks(self, query: str, url_match, ctx):
        """Resolve a matched Spotify playlist/album/track URL into pomice objects."""
        # The client only exists when both credentials were supplied, so check
        # it directly rather than the individual credential strings.
        if self._spotify_client is None:
            raise exceptions.InvalidSpotifyClientAuthorization("You did not provide proper Spotify client authorization credentials. If you would like to use the Spotify searching feature, please obtain Spotify API credentials here: https://developer.spotify.com/")

        search_type = url_match.group("type")
        spotify_id = url_match.group("id")

        if search_type == "playlist":
            results: spotify.Playlist = spotify.Playlist(
                client=self._spotify_client,
                data=await self._spotify_http_client.get_playlist(spotify_id),
            )
            try:
                search_tracks = await results.get_all_tracks()
                tracks = [self._build_spotify_track(track, ctx) for track in search_tracks]
                # tracks[0] raises IndexError for an empty playlist, which is
                # reported as a load failure below.
                return objects.Playlist(playlist_info={"name": results.name, "selectedTrack": tracks[0]}, tracks=tracks, ctx=ctx, spotify=True)
            except Exception as exc:
                raise exceptions.SpotifyPlaylistLoadFailed(f"Unable to find results for {query}") from exc

        if search_type == "album":
            results: spotify.Album = await self._spotify_client.get_album(spotify_id=spotify_id)
            try:
                search_tracks = await results.get_all_tracks()
                tracks = [self._build_spotify_track(track, ctx) for track in search_tracks]
                return objects.Playlist(playlist_info={"name": results.name, "selectedTrack": tracks[0]}, tracks=tracks, ctx=ctx, spotify=True)
            except Exception as exc:
                raise exceptions.SpotifyAlbumLoadFailed(f"Unable to find results for {query}") from exc

        # SPOTIFY_URL_REGEX only admits album|playlist|track, so this is "track".
        try:
            results: spotify.Track = await self._spotify_client.get_track(spotify_id=spotify_id)
            return [self._build_spotify_track(results, ctx)]
        except Exception as exc:
            raise exceptions.SpotifyTrackLoadFailed(f"Unable to find results for {query}") from exc

    async def _load_lavalink_tracks(self, query: str, ctx):
        """Query Lavalink's /loadtracks endpoint and convert the response."""
        async with self._session.get(
            url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}",
            headers={"Authorization": self._password},
        ) as response:
            data = await response.json()

        load_type = data.get("loadType")

        if not load_type:
            raise exceptions.TrackLoadError("There was an error while trying to load this track.")

        if load_type == "LOAD_FAILED":
            # Lavalink v3 nests failure details under "exception"
            # ({"message", "severity"}); top-level keys are a fallback.
            error = data.get("exception") or data
            severity = error.get("severity", "UNKNOWN")
            cause = error.get("cause") or error.get("message", "")
            raise exceptions.TrackLoadError(f"There was an error of severity '{severity}' while loading tracks.\n\n{cause}")

        if load_type == "NO_MATCHES":
            return None

        if load_type == "PLAYLIST_LOADED":
            return objects.Playlist(playlist_info=data["playlistInfo"], tracks=data["tracks"], ctx=ctx)

        if load_type in ("SEARCH_RESULT", "TRACK_LOADED"):
            return [objects.Track(track_id=track["track"], info=track["info"], ctx=ctx) for track in data["tracks"]]

        return None
|
|
|
|
|