AI/MachineLearning

[DeepLearning] 팔머 팽귄 데이터

리네엔 2022. 9. 29. 18:02

1. 데이터 확인

데이터 불러오기

  • 판다스를 이용해 csv파일 읽어오기
  • traindata = pd.read_csv(datasetFolderPath+trainsetName) traindata.dtypes
  • 널데이터 확인 및 제거데이터 분포 확인하기
  • traindata.isnull().sum() traindata.dropna(inplace=True)
  • target과 다른 정보들간에 관계가 있는지 확인하기
    sns.histplot(x="species",y="island",data=traindata)

    sns.histplot(x="species",y="culmen_length_mm",data=traindata)

    sns.histplot(x="species",y="culmen_depth_mm",data=traindata)

    sns.histplot(x="species",y="flipper_length_mm",data=traindata)

    sns.histplot(x="species",y="body_mass_g",data=traindata)

    sns.histplot(x="species",y="sex",data=traindata)
  • 연관 없는 데이터 드랍
    traindata.drop("sex",inplace=True,axis=1)

2. 데이터 전처리

def makedata(path,Train=True) :
    x = pd.read_csv(path)
    x.dropna(inplace=True)
    x.drop('sex',inplace=True,axis=1)
    for var in ('species','island') :
        _, x[var] = np.unique(x[var],return_inverse=True)

    label = np.ravel(x['species'])
    x.drop(['species'],inplace=True,axis=1)
    x = (x-np.mean(x,axis=0))/np.std(x,axis=0)
    x = torch.from_numpy(x.to_numpy(np.float32))

    label = torch.from_numpy(label)
    y = F.one_hot(label).float()
    print(y.shape)

    return x, y

3. 학습

train_x,train_y = makedata(datasetFolderPath+traindsetName)
test_x,test_y = makedata(datasetFolderPath+testsetName,Train=False)

model = nn.Linear(5,3)
optimizer = torch.optim.SGD(model.parameters(),lr=0.5)
lossfn = nn.CrossEntropyLoss()
epochs =10
for epoch in range(epochs) :
    pred = model(train_x)
    cost = lossfn(pred,train_y)

    optimizer.zero_grad()
    cost.backward()
     optimizer.step()

     rint(f'epoch : {epoch} | cost : {cost.item()}')
print("Done")

4. 테스트

model.eval()
test_loss, correct = 0, 0
with torch.no_grad() :
    predtion = model(test_x)
    test_loss=lossfn(predtion,test_y).item()
    correct += (predtion.argmax(1) == test_y.argmax(1)).type(torch.float).sum().item()
correct /= len(test_y)
print(f'testset Loss : {test_loss:>.3f},Accuracy : {correct*100:>.3f}%')

https://www.wenyanet.com/opensource/ko/611377af0524566054290c16.html

'AI > MachineLearning' 카테고리의 다른 글

[기계학습] 1. 기계학습 개요 및 용어  (0) 2022.03.22