In [4]:
from astropy.visualization import ZScaleInterval
import os
import os.path as op
import numpy as np
import glob
import pandas as pd
import seaborn as sns
import tables
from astropy.table import QTable, Table, Column
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import plotly.express as px
import plotly.graph_objects as go
%matplotlib notebook

In [20]:
#get a list of all the detectids (names of files)

# Determine directory to data cutouts:

if op.exists("cutouts_txt"):
    datadir = "cutouts_txt/"
elif op.exists("/home/jovyan/Hobby-Eberly-Public/HETDEX/shared/AST376R"):
    datadir = "/home/jovyan/Hobby-Eberly-Public/HETDEX/shared/AST376R/cutouts_txt/"
else:
    print("Cannot find data directory path to cutout_txt. Manually enter path")


In [15]:
#datdir = "path to files"
print(datadir)

/home/jovyan/Hobby-Eberly-Public/HETDEX/shared/AST376R/


In [37]:
files = [f for f in glob.glob(datadir + '*.txt')]

detect_list = [int( f.split('/')[-1].split('.')[0]) for f in files]

In [41]:
#grabs the 2D spectra from the cutouts_txt directory (in .txt files). make a dataframe that is readable by tsne
cutout_2d_spec_arrays = []
cutout_wave_arrays = []
cutout_detectids = []

for d in detect_list[0:1000]: #runs the first 1000...can run more/less, change numbers in brackets
    data = np.loadtxt( datadir + str(d)+'.txt')
    wave = data[0]
    spec_2d = data[1:]
    
    cutout_2d_spec_arrays.append(spec_2d.flatten())
    cutout_wave_arrays.append(wave)
    cutout_detectids.append(d)
    
spec_arrs = np.array(cutout_2d_spec_arrays) 
cutout_df = pd.DataFrame(spec_arrs)
print('Total cutouts made: '+str(len(cutout_detectids)))

Total cutouts made: 1000


In [42]:
#runs tsne: play around with components and perplexity (just a number, here I used the square root of the number of objects)
#spits out "x" and "y" values for each detectid/2D image we fed in)
#these are meaningless on their own, but occupy regions when plotted 
X_embedded = TSNE(n_components = 2, perplexity = int(np.sqrt(len(cutout_detectids))),
                  verbose=1).fit_transform(cutout_df)



[t-SNE] Computing 94 nearest neighbors...
[t-SNE] Indexed 1000 samples in 0.000s...
[t-SNE] Computed neighbors for 1000 samples in 0.076s...
[t-SNE] Computed conditional probabilities for sample 1000 / 1000
[t-SNE] Mean sigma: 12.231047
[t-SNE] KL divergence after 250 iterations with early exaggeration: 117.492126
[t-SNE] KL divergence after 1000 iterations: 1.865738


In [43]:
#make a table/dataframe including the detectid and tsne values 
#will be passed into the plotting function
output_Table = Table()
output_Table['detectid'] = cutout_detectids
output_Table['tsne_x'] = X_embedded[:,0]
output_Table['tsne_y'] = X_embedded[:,1]
output_df = pd.DataFrame(np.array(output_Table))

In [44]:
#makes an interactive plot of the tsne x and y values.
scat = px.scatter(output_df,x='tsne_x', y='tsne_y', hover_data=['detectid'])
fig = go.FigureWidget(scat)
fig

FigureWidget({
    'data': [{'customdata': array([[3000215113],
                                   [3000256359],
                                   [3000157024],
                                   ...,
                                   [3000166444],
                                   [3000262599],
                                   [3000136433]]),
              'hovertemplate': 'tsne_x=%{x}<br>tsne_y=%{y}<br>detectid=%{customdata[0]}<extra></extra>',
              'legendgroup': '',
              'marker': {'color': '#636efa', 'symbol': 'circle'},
              'mode': 'markers',
              'name': '',
              'orientation': 'v',
              'showlegend': False,
              'type': 'scatter',
              'uid': 'fc827ccc-5d05-42ca-9d03-1a48003c6cf6',
              'x': array([-5.7932525, -2.7747352, -6.1321774, ..., -3.4930298, -5.472426 ,
                          -6.0453777], dtype=float32),
              'xaxis': 'x',
              'y': array([-1.1957278 , -0.6713427

In [48]:
#plot the 2D spectrum cutout a specific detectid 
sel_det = 3000371665 #can change this to any other detectid
sel = cutout_detectids.index(sel_det)
height=9

spec_2d = np.array(cutout_2d_spec_arrays[sel]).reshape((9, 50))
wave = cutout_wave_arrays[sel]

zscale = ZScaleInterval(contrast=0.5,krej=2.5)
vmin, vmax = zscale.get_limits(values=spec_2d)
plt.figure(figsize=(10,3))
plt.imshow(spec_2d,vmin=vmin, vmax=vmax,extent=[wave[0], wave[-1], -int(height/2.), int(height/2.)], origin="lower",cmap=plt.get_cmap('gray'),interpolation="none")
plt.show()

<IPython.core.display.Javascript object>