상세 컨텐츠

본문 제목

221007 머신러닝 1주차

카테고리 없음

by hunss 2022. 10. 7. 17:15

본문

선형회귀 위주

숙제 풀이

import os
os.environ['KAGGLE_USERNAME'] = 'nhkmi1001'
os.environ['KAGGLE_KEY'] = 'de5b80115e8959894ce4eb0cf23b4b63' 
 
!kaggle datasets download -d rsadiq/salary  #kaggle에서 데이터베이스 다운받아오기
!unzip salary.zip #알집풀기
 
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam, SGD
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns
from sklearn.model_selection import train_test_split

df = pd.read_csv('Salary.csv')

df.tail(5)
x_data = np.array(df['YearsExperience'], dtype=np.float32)
y_data = np.array(df['Salary'], dtype=np.float32)

x_data = x_data.reshape((-11))
y_data = y_data.reshape((-11))

print(x_data.shape)
print(y_data.shape)

x_train, x_val, y_train, y_val = train_test_split(x_data, y_data, test_size=0.2, random_state=2021)

print(x_train.shape, x_val.shape)
print(y_train.shape, y_val.shape)

model = Sequential([
  Dense(1)
])

model.compile(loss='mean_squared_error', optimizer=SGD(lr=0.01))

model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=100
)
y_pred = model.predict(x_val)

plt.scatter(x_val, y_val)
plt.scatter(x_val, y_pred, color='r')
plt.show()