<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom">
  <channel>
    <title>Deep learning</title>
    <description>Personal blog and resume</description>
    <link>https://woosikyang.github.io/</link>
    <atom:link href="https://woosikyang.github.io/feed.xml" rel="self" type="application/rss+xml"/>
    <pubDate>Thu, 29 Apr 2021 15:00:03 +0000</pubDate>
    <lastBuildDate>Thu, 29 Apr 2021 15:00:03 +0000</lastBuildDate>
    <generator>Jekyll v3.9.0</generator>
    
      <item>
        <title></title>
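        <description>&lt;p&gt;Hello. The paper I review today is From Recognition to Cognition: Visual Commonsense Reasoning. Existing VQA models stop at finding the answer to a simple question. This paper instead introduces a new dataset and is designed so that a model learns both to find the answer and to explain why that answer is correct.&lt;/p&gt;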

&lt;h2 id=&quot;vcr&quot;&gt;VCR&lt;/h2&gt;

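&lt;p&gt;The paper aims to bring human commonsense into question answering. For a given question, most people arrive at nearly the same answer because they rely on shared commonsense. To learn this commonsense, that is, the grounds that lead to an answer, the VCR dataset adds rationale data that explains 'why'. This addresses a limitation in what existing VQA datasets can teach a model.&lt;/p&gt;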

&lt;h2 id=&quot;데이터셋-구성&quot;&gt;Dataset composition&lt;/h2&gt;

&lt;ol&gt;
  &lt;li&gt;
    &lt;p&gt;290,000 multiple-choice questions&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Rationales for those 290,000 multiple-choice questions&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;110,000 images&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Challenging, diverse questions with unbiased, diverse answer choices&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Average answer length: 7.5 words / average rationale length: 16 words&lt;/p&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/eS1mqh6.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Dataset composition &lt;/p&gt;

&lt;h2 id=&quot;적용한-모델r2c&quot;&gt;The proposed model (R2C)&lt;/h2&gt;

&lt;p&gt;The model proposed in this paper is called R2C (Recognition to Cognition Networks). Answering a question about an image takes several inference steps. First, the question and the answer must be understood and linked to the given image. &lt;strong&gt;(Grounding)&lt;/strong&gt; Second, the information from the question, the image and the answer is combined into a joint judgment. &lt;strong&gt;(Contextualization)&lt;/strong&gt; Finally, commonsense reasoning is carried out over that joint judgment. &lt;strong&gt;(Reasoning)&lt;/strong&gt; The model is built to follow the same inference steps.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/NP8hcFJ.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; R2C architecture &lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;Grounding&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;For grounding, the model learns a joint representation of the image, the question and the answer for each token in the sequence. The grounding module uses a bidirectional LSTM. Object information is extracted from the image with a CNN, using RoI features taken from the bounding boxes, and the object class labels are used as well.&lt;/p&gt;

&lt;ol start=&quot;2&quot;&gt;
  &lt;li&gt;Contextualization&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Once the question and answer representations are obtained, an attention mechanism extracts context between the answer and the question, and between the answer and the image.&lt;/p&gt;
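
&lt;p&gt;The contextualization step can be pictured as bilinear attention from each answer token over the question tokens, and separately over the detected objects. Below is a minimal NumPy sketch. It assumes the bilinear form softmax_j(r_i W q_j), which is my understanding of how R2C scores answer-question pairs. The sizes, the random weights and the helper names (softmax, bilinear_attend) are hypothetical.&lt;/p&gt;

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax along the given axis
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def bilinear_attend(r, q, W):
    """Attend each response token over a set of context vectors.

    r: (n_r, d) grounded answer-token features
    q: (n_q, d) grounded question-token (or object) features
    W: (d, d)   bilinear weight (hypothetical, random here)
    Returns the attended context (n_r, d) and the weights (n_r, n_q).
    """
    scores = r @ W @ q.T               # (n_r, n_q): r_i W q_j for every pair
    alpha = softmax(scores, axis=1)    # normalize over the context items j
    return alpha @ q, alpha

rng = np.random.default_rng(0)
d = 8
r = rng.normal(size=(5, d))   # 5 answer tokens (toy)
q = rng.normal(size=(7, d))   # 7 question tokens (toy)
o = rng.normal(size=(3, d))   # 3 detected objects (toy)
W_q = rng.normal(size=(d, d)) * 0.1
W_o = rng.normal(size=(d, d)) * 0.1

q_hat, alpha = bilinear_attend(r, q, W_q)   # question context for each answer token
o_hat, beta = bilinear_attend(r, o, W_o)    # image (object) context for each answer token
print(q_hat.shape, o_hat.shape)             # (5, 8) (5, 8)
```

&lt;p&gt;Each answer token ends up with one question-context vector and one image-context vector. Together with the token itself, these are what the reasoning step consumes.&lt;/p&gt;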

&lt;ol start=&quot;3&quot;&gt;
  &lt;li&gt;Reasoning&lt;/li&gt;
&lt;/ol&gt;

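&lt;p&gt;Finally, the attended question, the attended image features and the answer obtained above are fed into a bidirectional LSTM. The task is then solved with a multi-class cross-entropy loss over the candidate responses. BERT is used to embed the questions and answers, and ResNet-50 is used to extract image features.&lt;/p&gt;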
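
&lt;p&gt;The multi-class cross-entropy can be illustrated for a single question. VCR questions are four-way multiple choice, so each question has four candidate responses. This is a minimal sketch that assumes the reasoning module has already produced one scalar score per candidate. The scores below are made up, and multiclass_ce is a hypothetical helper.&lt;/p&gt;

```python
import numpy as np

def multiclass_ce(logits, label):
    """Cross-entropy over the candidate responses of one question.

    logits: (num_choices,) one score per candidate answer (or rationale),
            assumed to come from the reasoning module
    label:  index of the correct candidate
    """
    z = logits - logits.max()                  # stabilize the exponent
    log_probs = z - np.log(np.exp(z).sum())    # log-softmax over candidates
    return -log_probs[label]

logits = np.array([2.0, 0.5, -1.0, 0.0])      # toy scores for 4 candidates
loss = multiclass_ce(logits, label=0)
pred = int(np.argmax(logits))                 # the chosen candidate
```

&lt;p&gt;With equal scores for all four candidates, the loss equals log 4, the cost of a uniform guess.&lt;/p&gt;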

&lt;h2 id=&quot;결과&quot;&gt;Results&lt;/h2&gt;

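&lt;p&gt;The experiments cover three settings:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Q-&amp;gt;A: choose the answer given the question.&lt;/li&gt;
  &lt;li&gt;QA-&amp;gt;R: choose the rationale given the question and the correct answer.&lt;/li&gt;
  &lt;li&gt;Q-&amp;gt;AR: choose both the answer and the rationale; a prediction counts only if both are correct.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;The results are shown below.&lt;/p&gt;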
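
&lt;p&gt;The three settings are easy to confuse, so here is a small sketch of how the accuracies relate. It assumes we already have, for every question, the chosen answer index, the chosen rationale index and the ground truth. The toy predictions and the helper name vcr_accuracies are hypothetical.&lt;/p&gt;

```python
import numpy as np

def vcr_accuracies(ans_pred, ans_gt, rat_pred, rat_gt):
    """Q->A, QA->R and Q->AR accuracy from per-question choices (indices 0-3).

    Q->AR counts a question as correct only if BOTH the chosen answer
    and the chosen rationale are correct.
    """
    ans_ok = np.asarray(ans_pred) == np.asarray(ans_gt)
    rat_ok = np.asarray(rat_pred) == np.asarray(rat_gt)
    return {"Q->A": float(ans_ok.mean()),
            "QA->R": float(rat_ok.mean()),
            "Q->AR": float(np.logical_and(ans_ok, rat_ok).mean())}

acc = vcr_accuracies(ans_pred=[0, 1, 2, 3], ans_gt=[0, 1, 0, 3],
                     rat_pred=[1, 1, 2, 0], rat_gt=[1, 0, 2, 0])
# Q->A = 0.75, QA->R = 0.75, Q->AR = 0.5
```

&lt;p&gt;With four choices each, random guessing scores 25% on Q-&amp;gt;A and QA-&amp;gt;R but only 1/16 = 6.25% on Q-&amp;gt;AR. This is why the joint setting is the hardest of the three.&lt;/p&gt;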

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/bSNHwcv.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Ablations of model  &lt;/p&gt;

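&lt;p&gt;R2C clearly outperforms BottomUpTopDown, a representative existing method. It is also worth noting that a model trained with BERT alone on the questions and answers, without the image, performs better than one might expect. This suggests how powerful BERT is on its own.&lt;/p&gt;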

&lt;h2 id=&quot;최종-결론&quot;&gt;Conclusion&lt;/h2&gt;

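&lt;p&gt;VCR is significant because it addresses the limitations of existing VQA datasets. Because it encourages learning in a way that resembles human reasoning, it is a dataset that can bring us closer to truly solving the VQA task. Above all, this paper was presented at ICCV 2019. Watching a paper I had been studying on my own being presented at a real conference made me want, all the more, to give meaningful presentations as a researcher myself. Going forward, I hope to do research with this data and achieve meaningful improvements.&lt;/p&gt;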

</description>
        <pubDate></pubDate>
        <link>https://woosikyang.github.io/2019-11-11-Visual-Commonsense-Reasoning.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/2019-11-11-Visual-Commonsense-Reasoning.html</guid>
        
        
      </item>
    
      <item>
        <title>Tips and tricks for VQA-learnings from 2017 challenge</title>
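        <description>&lt;p&gt;Hello. The paper I review today is Tips and tricks for VQA-learnings from 2017 challenge. VQA attracts a lot of interest, but it is a very difficult task, and it is hard to say that it has been truly solved. Drawing on the many approaches tried in earlier papers, this paper presents a robust VQA model and explains which features are effective for VQA.&lt;/p&gt;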

&lt;h2 id=&quot;주요-특징&quot;&gt;Key features&lt;/h2&gt;

&lt;ol&gt;
  &lt;li&gt;
    &lt;p&gt;Sigmoid outputs replace the usual single-label softmax, so that multiple correct answers per question can be used.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Training is framed as regression against ground-truth targets given as soft scores, rather than as classification.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;A gated tanh activation is used in every non-linear layer.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Image features come from bottom-up attention instead of a conventional CNN grid.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;The output-layer weights are initialized with pretrained representations of the candidate answers.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Stochastic gradient training uses large mini-batches and smart shuffling of the training data.&lt;/p&gt;
  &lt;/li&gt;
&lt;/ol&gt;
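
&lt;p&gt;Three of the tricks above are simple enough to sketch directly: independent sigmoid outputs, soft-score targets and the gated tanh layer. Below is a minimal NumPy sketch. The gated tanh form, tanh(Wx + b) multiplied element-wise by sigmoid(W'x + b'), follows the paper as I understand it. The soft score min(1, n/3), where n is the number of annotators who gave an answer, is a common simplification of the VQA accuracy metric and is an assumption here, not necessarily the paper's exact formula. All helper names are hypothetical.&lt;/p&gt;

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_tanh(x, W, b, W_g, b_g):
    """Gated tanh layer: tanh(W x + b) multiplied element-wise by sigmoid(W_g x + b_g)."""
    return np.tanh(W @ x + b) * sigmoid(W_g @ x + b_g)

def soft_score(num_agreeing):
    """Simplified VQA-style soft target min(1, n / 3) (an assumption of this sketch)."""
    return np.minimum(1.0, np.asarray(num_agreeing, dtype=float) / 3.0)

def soft_bce(logits, targets, eps=1e-12):
    """Binary cross-entropy of independent sigmoid outputs against soft targets,
    summed over candidate answers (several answers can be partly correct)."""
    p = sigmoid(logits)
    return -np.sum(targets * np.log(p + eps) + (1 - targets) * np.log(1 - p + eps))

rng = np.random.default_rng(0)
x = rng.normal(size=4)
y = gated_tanh(x, rng.normal(size=(3, 4)), np.zeros(3),
               rng.normal(size=(3, 4)), np.zeros(3))   # 3 outputs, each in (-1, 1)

targets = soft_score([3, 1, 0])   # 3, 1 and 0 annotators: [1.0, 0.333..., 0.0]
loss = soft_bce(np.array([2.0, -0.5, -3.0]), targets)
```

&lt;p&gt;Because each answer gets its own sigmoid, an answer that only one annotator gave can still receive partial credit. A single softmax would force the answers to compete for one label.&lt;/p&gt;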

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/xzowGEQ.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Performance of the proposed model &lt;/p&gt;

&lt;h2 id=&quot;기존-vqa-approch&quot;&gt;The standard VQA approach&lt;/h2&gt;

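&lt;p&gt;The paper first introduces the standard VQA approach, which has three parts:&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;Treat question answering as a classification problem.&lt;/li&gt;
  &lt;li&gt;Solve it with a deep neural network built on a joint embedding.&lt;/li&gt;
  &lt;li&gt;Train it end to end with supervised learning.&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;Interestingly, it has been observed that even a very simple model can perform well when it is given good hyperparameters. The paper likewise stresses that it obtains strong results with a simple model that uses only the key features listed above.&lt;/p&gt;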

&lt;h2 id=&quot;제안하는-모델&quot;&gt;Proposed model&lt;/h2&gt;

&lt;p&gt;The model proposed in the paper is shown in the figure below. Does it look familiar? It is very similar to Bottom-Up and Top-Down Attention for Image Captioning and VQA, which I introduced in a previous post. As the figure shows, the good performance comes from the small set of variations described at the beginning of this post.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/qrULeDo.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Proposed model &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/qrULeDo.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Bottom-Up Attention Model &lt;/p&gt;

&lt;h2 id=&quot;ablations-of-model&quot;&gt;Ablations of model&lt;/h2&gt;

&lt;p&gt;What caught my eye in this paper was the network ablation table. It is a very valuable resource because it shows how much, and in what way, each architectural component affects performance on the VQA task.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/wqh7rAI.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Ablations of model  &lt;/p&gt;

&lt;h2 id=&quot;최종-결론&quot;&gt;Conclusion&lt;/h2&gt;

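&lt;p&gt;This paper presents a VQA model with strong performance. Rather than an entirely new architecture, it uses a simple model that had already been proposed. It thereby shows that for VQA, robust performance can come from design choices, hyperparameter choices and careful implementation details alone, without a complex model. As the paper itself states, its results can serve as a solid baseline. I hope to build better models through a deeper understanding of VQA and new attempts. Thank you for reading.&lt;/p&gt;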
</description>
        <pubDate>Sun, 10 Nov 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/Tips-and-tricks-for-VQA-learnings-from-2017-challenge.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/Tips-and-tricks-for-VQA-learnings-from-2017-challenge.html</guid>
        
        <category>VQA</category>
        
        
        <category>VQA</category>
        
      </item>
    
      <item>
        <title>Graph Attention Network</title>
        <description>&lt;p&gt;Hello. The paper I review today is Graph Attention Networks. Attention is a representative deep learning technique, used both to analyze which inputs matter and to improve performance. GAT likewise uses attention to improve the performance of GNNs. Let's move on to the summary. The content and figures of this post draw on 
&lt;a href=&quot;https://pozalabs.github.io/transformer/&quot;&gt;&lt;strong&gt;Attention&lt;/strong&gt;&lt;/a&gt;,
&lt;a href=&quot;https://github.com/PetarV-/GAT&quot;&gt;&lt;strong&gt;GAT_github&lt;/strong&gt;&lt;/a&gt; and &lt;a href=&quot;https://openreview.net/pdf?id=rJXMpikCZ&quot;&gt;&lt;strong&gt;the GAT paper&lt;/strong&gt;&lt;/a&gt;.&lt;/p&gt;

&lt;h2 id=&quot;들어가며&quot;&gt;Introduction&lt;/h2&gt;

&lt;p&gt;The core idea of GAT is applying multi-head attention to each node. 
Multi-head attention is a technique from the Transformer, a representative attention model. As in the figure below, the full dimension is split and several linear projections are computed in parallel, which yields richer representations.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://pozalabs.github.io/assets/images/multi%20head.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Multi-head Attention&lt;/p&gt;

&lt;p&gt;The figure below shows multi-head attention applied to node 1. Arrows of different styles and colors are the attention values computed by the different heads.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://camo.githubusercontent.com/4381475b2a8cf1bf6213e4dcddf89f87ba8422fc/687474703a2f2f7777772e636c2e63616d2e61632e756b2f7e70763237332f696d616765732f6761742e6a7067&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; GAT Layer&lt;/p&gt;

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;Architecture&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;A GAT layer takes N nodes with F features each as input and produces N nodes with \(F'\) features each as output.&lt;/p&gt;

\[h = \{h_1, h_2, \ldots, h_N\}, \quad h_i \in \mathbb{R}^F\]

\[h' = \{h'_1, h'_2, \ldots, h'_N\}, \quad h'_i \in \mathbb{R}^{F'}\]

&lt;p&gt;To turn the input features into higher-level representations like these, a linear transformation is required. GAT uses a shared weight matrix \(W\) for this and applies self-attention to the nodes. Self-attention yields attention coefficients \(e_{ij}\), which indicate how important node j's features are to node i.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/7Cuipu1.png&quot; title=&quot;source: imgur.com&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;attention coefficients&lt;/p&gt;

&lt;p&gt;To preserve the graph structure, coefficients are computed only for nodes j that are connected to node i (masked attention). 
The coefficients are then normalized across node i's neighbors with a softmax.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/TClz6sd.png&quot; title=&quot;source: imgur.com&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;Softmax function &lt;/p&gt;

&lt;p&gt;In the paper, the attention mechanism is a single-layer feed-forward network, with the nonlinear LeakyReLU as its activation function. The final form is shown below.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/ChHzxFC.png&quot; title=&quot;source: imgur.com&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;final coefficients mechanism &lt;/p&gt;

&lt;p&gt;The figure below illustrates this process. The attention coefficients are finally obtained through the softmax.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/iY3XY8M.png&quot; title=&quot;source: imgur.com&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;attention mechanism &lt;/p&gt;
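
&lt;p&gt;The coefficient computation can be written in a few lines of NumPy for a single head. The score is LeakyReLU applied to a^T [Wh_i || Wh_j], the mask keeps only neighbors and the node itself, and a softmax runs over each row. This is a minimal sketch assuming dense adjacency matrices and the 0.2 negative slope used in the paper. The helper names and the toy graph are hypothetical, and this is not the reference implementation from the GAT repository.&lt;/p&gt;

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    # GAT uses LeakyReLU with a negative-input slope of 0.2
    return np.where(x > 0, x, slope * x)

def gat_attention(h, adj, W, a):
    """Attention coefficients alpha_ij of a single GAT head.

    h:   (N, F)    input node features
    adj: (N, N)    0/1 adjacency matrix; self-loops are added so i attends to itself
    W:   (F, F2)   shared linear transformation (F2 plays the role of F')
    a:   (2 * F2,) attention vector, applied to [W h_i || W h_j]
    """
    Wh = h @ W                                            # (N, F2)
    F2 = Wh.shape[1]
    # a^T [W h_i || W h_j] = a[:F2] . W h_i + a[F2:] . W h_j, for all pairs at once
    e = leaky_relu((Wh @ a[:F2])[:, None] + (Wh @ a[F2:])[None, :])   # (N, N)
    mask = (adj + np.eye(len(adj))) > 0                   # neighbors of i, plus i
    e = np.where(mask, e, -np.inf)                        # masked attention
    alpha = np.exp(e - e.max(axis=1, keepdims=True))      # softmax over each row
    return alpha / alpha.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]])                               # path graph 0 - 1 - 2
alpha = gat_attention(rng.normal(size=(3, 4)), adj,
                      rng.normal(size=(4, 2)), rng.normal(size=4))
```

&lt;p&gt;Because nodes 0 and 2 are not connected, alpha[0, 2] is exactly zero. Every row still sums to 1 over the node's own neighborhood.&lt;/p&gt;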

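&lt;p&gt;The attention coefficients obtained above are used as the weights of a linear combination of the neighbors' transformed features. This linear combination gives each node's output features.&lt;/p&gt;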

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/lRWX6FX.png&quot; title=&quot;source: imgur.com&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;linear combination using attention coefficients &lt;/p&gt;

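&lt;p&gt;Multi-head attention is applied at this step. As mentioned above, K independent attention heads each perform the computation separately. Their outputs are then concatenated, so each node ends up with \(KF'\) features. This makes learning more stable and the representation richer.&lt;/p&gt;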

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/fmID73G.png&quot; title=&quot;source: imgur.com&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;multi-head attention &lt;/p&gt;

&lt;p&gt;In the final layer of the network, the heads are averaged instead of concatenated. The final nonlinearity is applied after the averaging.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/3vUc8t4.png&quot; title=&quot;source: imgur.com&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;Averaging multi-head representation &lt;/p&gt;
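
&lt;p&gt;Putting aggregation and multi-head attention together gives a full layer. Each head computes sum_j alpha_ij W h_j. Hidden layers apply a nonlinearity and concatenate the K heads into K * F' features, while the last layer averages the heads. This is a minimal NumPy sketch under the same assumptions as a dense single-head version. ELU in hidden layers is what the paper uses on its transductive benchmarks, and the helper names and sizes are hypothetical.&lt;/p&gt;

```python
import numpy as np

def elu(x):
    # ELU, used as the hidden-layer nonlinearity in this sketch
    return np.where(x > 0, x, np.expm1(x))

def attention_head(h, mask, W, a):
    """One head: sum_j alpha_ij W h_j, before any nonlinearity."""
    Wh = h @ W
    F2 = Wh.shape[1]
    e = (Wh @ a[:F2])[:, None] + (Wh @ a[F2:])[None, :]
    e = np.where(e > 0, e, 0.2 * e)                  # LeakyReLU(0.2)
    e = np.where(mask, e, -np.inf)                   # neighbors (and self) only
    alpha = np.exp(e - e.max(axis=1, keepdims=True))
    alpha /= alpha.sum(axis=1, keepdims=True)        # softmax over neighbors
    return alpha @ Wh                                # linear combination of W h_j

def gat_layer(h, adj, Ws, As, final=False):
    """K-head GAT layer: concatenate the heads, or average them in the final layer."""
    mask = (adj + np.eye(len(adj))) > 0
    heads = [attention_head(h, mask, W, a) for W, a in zip(Ws, As)]
    if final:
        # average the K heads; a softmax/sigmoid for prediction would follow
        return np.mean(heads, axis=0)
    return np.concatenate([elu(x) for x in heads], axis=1)   # K * F2 features

rng = np.random.default_rng(0)
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
h = rng.normal(size=(3, 4))
K, F2 = 3, 2
Ws = [rng.normal(size=(4, F2)) for _ in range(K)]
As = [rng.normal(size=2 * F2) for _ in range(K)]
hidden = gat_layer(h, adj, Ws, As)             # shape (3, 6): K * F2 features per node
out = gat_layer(h, adj, Ws, As, final=True)    # shape (3, 2): heads averaged
```

&lt;p&gt;Averaging in the last layer keeps the output size equal to the number of classes. Concatenation there would multiply it by K.&lt;/p&gt;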

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;Application&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The figure below visualizes the representations learned by GAT using t-SNE. Visualized together with the attention coefficients, it shows that GAT learns representations on the Cora dataset that reflect the characteristics of the data. The paper also shows that GAT outperforms GCN and GraphSAGE.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://camo.githubusercontent.com/a1ad7645e034ba75ab4d3380a631fdfc00783553/687474703a2f2f7777772e636c2e63616d2e61632e756b2f7e70763237332f696d616765732f6761745f74736e652e6a7067&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;t-SNE + Attention coefficients on Cora &lt;/p&gt;

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;Conclusion&lt;/strong&gt;&lt;/p&gt;

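&lt;p&gt;GAT brought the attention mechanism, which has become prominent in deep learning, to GNNs and improved their performance. It overcomes GCN's limitation of a fixed filter while also exploiting the strengths of attention, so it has many potential applications. GAT is also widely used in Visual Question Answering, the field I study, as an effective way to improve performance. I am currently researching how to build better methods on top of GAT.&lt;/p&gt;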

</description>
        <pubDate>Tue, 29 Oct 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/Graph-Attention-Network.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/Graph-Attention-Network.html</guid>
        
        <category>Graph</category>
        
        
        <category>Graph</category>
        
      </item>
    
      <item>
        <title>Graph Convolutional Network</title>
        <description>&lt;p&gt;Hello. The paper I review today is Graph Convolutional Networks. GCN is a type of Graph Neural Network for graph-structured data. Although the paper dates from 2016, I think it is a good starting point for graph-based methods because it carries the properties of convolution filters over to graphs. Let's move on to the summary. The content and figures of this post draw on &lt;a href=&quot;https://tkipf.github.io/graph-convolutional-networks/&quot;&gt;&lt;strong&gt;GCN_github&lt;/strong&gt;&lt;/a&gt; and &lt;a href=&quot;https://towardsdatascience.com/how-to-do-deep-learning-on-graphs-with-graph-convolutional-networks-7d2250723780&quot;&gt;&lt;strong&gt;GCN_towardsdatascience&lt;/strong&gt;&lt;/a&gt;.&lt;/p&gt;

&lt;h2 id=&quot;들어가며&quot;&gt;Introduction&lt;/h2&gt;

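&lt;p&gt;Images are perhaps the most representative application area of deep learning. Deep learning drew renewed attention when convolution filters brought tremendous progress in image classification. However, convolution filters are only effective on data with a fixed grid structure, such as images. GCN achieves a convolution-like effect on graphs, which makes effective feature extraction and learning possible on data that is not laid out on a grid.&lt;/p&gt;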

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://tkipf.github.io/graph-convolutional-networks/images/gcn_web.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Example of a multi-layer GCN&lt;/p&gt;

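&lt;p&gt;The figure above shows a GCN with two hidden layers. Data passes through the hidden layers and activation functions during training, so in that respect it is not very different from a standard MLP. What matters is the structure of each hidden layer.&lt;/p&gt;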

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;Definitions&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;As the example shows, GCN has a simple structure. Consider a graph G(N, E) with N nodes and E edges. If each node is embedded in d dimensions, we obtain an \(n \times d\) input matrix. A GNN additionally uses an adjacency matrix that encodes the edge information.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Graph(N,E)
    &lt;ul&gt;
      &lt;li&gt;input &lt;strong&gt;X&lt;/strong&gt; :  \(n \times d\)&lt;/li&gt;
      &lt;li&gt;adjacency matrix &lt;strong&gt;A&lt;/strong&gt; : \(n \times n\)&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ul&gt;

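&lt;p&gt;X and A are then used in each hidden layer, and the result is passed on to the next hidden layer. Letting \(H^0 = X\), 
the value of the l-th hidden layer can be defined with a layer-wise function f as follows:&lt;/p&gt;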

\[H^l = f( H^{l-1}, A)\]

&lt;p&gt;The data passes through the layer function f at every layer, from the input up to the prediction. A simple example is:&lt;/p&gt;

\[f( H^{l}, A) = \sigma (AH^{l}W^{l})\]

&lt;p&gt;In other words, GCN learns weights just as an MLP does, but it incorporates information from the graph structure while doing so.&lt;/p&gt;
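
&lt;p&gt;The simple propagation rule \(f(H^{l}, A) = \sigma(AH^{l}W^{l})\) takes one line in NumPy. Below is a minimal sketch that assumes ReLU for \(\sigma\), a toy 4-node graph, one-hot input features and all-ones weights. The graph, the weights and the helper name gcn_layer_naive are hypothetical.&lt;/p&gt;

```python
import numpy as np

def gcn_layer_naive(H, A, W):
    """Simplest propagation rule f(H, A) = ReLU(A H W).

    Row i of A H sums the features of node i's neighbors; W is a shared
    linear map and ReLU the nonlinearity.
    """
    return np.maximum(0.0, A @ H @ W)

A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 1],
              [0, 1, 0, 1],
              [0, 1, 1, 0]], dtype=float)
H0 = np.eye(4)                 # X: one-hot features, n = 4, d = 4
W0 = np.ones((4, 2))           # toy weights
H1 = gcn_layer_naive(H0, A, W0)
print(H1[:, 0])                # [1. 3. 2. 2.]  (each node's degree)
```

&lt;p&gt;The output exposes two weaknesses of this naive rule. Because A has no self-loops, each node ignores its own features, and the values grow with node degree. The practical fixes described next (self-loops and normalization) target exactly these two issues.&lt;/p&gt;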

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;Practical Implementation&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Adding self-loops (an edge from each node to itself)&lt;/li&gt;
&lt;/ul&gt;

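&lt;p&gt;In practice, self-loops must be added to the adjacency matrix that encodes the relations between nodes. This is done by adding the identity matrix I to A, that is, by using A + I. With self-loops, each node's representation takes into account its own embedding as well as its relations to other nodes.&lt;/p&gt;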

&lt;ul&gt;
  &lt;li&gt;Applying normalization&lt;/li&gt;
&lt;/ul&gt;

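&lt;p&gt;Normalization is needed to obtain robust representations. Because each node aggregates over its neighbors, nodes with many edges (high degree) and nodes with few can end up on very different scales. To reduce this gap, the values are rescaled by normalization. In graph theory, the adjacency matrix is normalized by multiplying it by the inverse of the diagonal degree matrix, \(D^{-1}A\), which turns the sum over neighbors into an average. The GCN paper itself uses the symmetric form \(\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}\), where \(\hat{A} = A + I\) and \(\hat{D}\) is the degree matrix of \(\hat{A}\).&lt;/p&gt;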
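
&lt;p&gt;Both normalizations, together with the self-loops, can be sketched in NumPy. This is a minimal sketch assuming a dense adjacency matrix. normalize_adjacency and gcn_layer are hypothetical helper names, and the symmetric variant follows the renormalization trick of the GCN paper as I understand it.&lt;/p&gt;

```python
import numpy as np

def normalize_adjacency(A, symmetric=True):
    """Add self-loops (A + I) and normalize by node degree.

    symmetric=False: D^-1 (A + I), row (random-walk) normalization
    symmetric=True:  D^-1/2 (A + I) D^-1/2, the form used in the GCN paper
    """
    A_hat = A + np.eye(len(A))
    deg = A_hat.sum(axis=1)
    if symmetric:
        d_inv_sqrt = 1.0 / np.sqrt(deg)
        return d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]
    return A_hat / deg[:, None]

def gcn_layer(H, A, W):
    """GCN layer with the renormalization trick: ReLU(D^-1/2 (A+I) D^-1/2 H W)."""
    return np.maximum(0.0, normalize_adjacency(A) @ H @ W)

A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 1],
              [0, 1, 0, 1],
              [0, 1, 1, 0]], dtype=float)
A_rw = normalize_adjacency(A, symmetric=False)
print(A_rw.sum(axis=1))        # [1. 1. 1. 1.]: each row is now an average
H1 = gcn_layer(np.eye(4), A, np.ones((4, 2)))
```

&lt;p&gt;Row normalization makes each node average over itself and its neighbors. The symmetric version keeps the matrix symmetric, and it is what the GCN paper derives from its spectral motivation.&lt;/p&gt;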

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;Application&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The paper shows results of applying GCN to semi-supervised learning. As the figures below show, GCN produces well-separated embeddings from the graph data.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://tkipf.github.io/graph-convolutional-networks/images/karate.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;Zachary's Karate Club Dataset &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://tkipf.github.io/graph-convolutional-networks/images/karate_emb.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;GCN embedding of each node &lt;/p&gt;

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;Conclusion&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;GCN is a foundational GNN approach that brings robust, convolution-like feature extraction to graph structures. It was followed by papers such as GraphSAGE and GAT, and I will cover GAT next.&lt;/p&gt;

</description>
        <pubDate>Mon, 28 Oct 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/Graph-Convolutional-Network.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/Graph-Convolutional-Network.html</guid>
        
        <category>Graph</category>
        
        
        <category>Graph</category>
        
      </item>
    
      <item>
        <title>Bottom-Up and Top-Down Attention for Image Captioning and VQA</title>
        <description>&lt;p&gt;Hello. Today I review Bottom-Up and Top-Down Attention for Image Captioning and VQA. VQA is a field I have long been interested in without really knowing how it works in practice. This paper was a good way to learn how state-of-the-art VQA is done. I plan to cover combinations of vision and NLP, such as VQA and image captioning, more often. Let's move on to the summary.&lt;/p&gt;

&lt;h2 id=&quot;들어가며&quot;&gt;Introduction&lt;/h2&gt;

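&lt;p&gt;Image captioning and VQA require understanding both images and language, and they attract interest because they span both computer vision and natural language processing. A key ingredient is the visual attention mechanism, which you can think of as a mechanism for finding which regions of the image to focus on.&lt;/p&gt;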

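&lt;p&gt;There are broadly two ways to understand an image. The first looks at the whole image and finds the features relevant to the task; this is called top-down. The second builds up features gradually from the pixel level; this is called bottom-up. Most current visual attention mechanisms are top-down. Because this resembles the human cognitive system, they have produced good results.&lt;/p&gt;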

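&lt;p&gt;However, top-down attention has little basis for deciding exactly which parts of the image to look at, so it cannot fully exploit the features of the given image. As the picture below shows, a task may require focusing on specific, task-relevant regions, the way a human would. The existing top-down approach struggles here because it only makes decisions over a grid of equally sized cells.&lt;/p&gt;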

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/3hHyTRW.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 1 &lt;/p&gt;

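&lt;p&gt;As the figure above suggests, a suitable bottom-up model can tell us which regions of the image to focus on, and that makes a more accurate understanding of the image possible. The goal of this paper is therefore to build stronger image captioning and VQA models by combining the existing top-down approach with an improved bottom-up approach.&lt;/p&gt;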

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;Bottom-Up Attention Model&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;As the bottom-up attention model to combine with top-down attention, the paper uses Faster R-CNN. Faster R-CNN combines RPN, a region proposal network, with the Fast R-CNN object detector, which enables fast and accurate object detection. The Faster R-CNN here is built on a ResNet-101 model pretrained on ImageNet and determines which regions of the image to detect and what their labels are. One difference from the standard Faster R-CNN is an added attribute predictor. It predicts attributes of each proposed region in addition to its class, which helps the model learn richer region features.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/g5XHsYj.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Faster R-CNN example &lt;/p&gt;

&lt;hr /&gt;

&lt;p&gt;The paper combines this bottom-up attention model with a different top-down model for each of the captioning and VQA tasks.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Captioning Model&lt;/strong&gt;&lt;/p&gt;

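&lt;p&gt;For captioning, the partial output sequence generated so far is used as context, and soft top-down attention computes the feature weights at each step of caption generation. According to the paper, this captioning model performs well even without bottom-up attention. It consists of two LSTM layers, and each layer plays a different role.&lt;/p&gt;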

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/3SMD9Vq.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;Architecture of the captioning model &lt;/p&gt;

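&lt;p&gt;The figure above shows the structure of the captioning model. The first layer is the top-down attention LSTM, which takes three inputs: the previous output of the language LSTM, the mean-pooled image feature and the previously generated word. Its output \(h_t\) drives soft attention over the bottom-up region features. The resulting attention-weighted sum of those features, \(\hat{v}_t\), is the input to the second LSTM layer, the language LSTM.&lt;/p&gt;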
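
&lt;p&gt;The attention step between the two LSTMs can be sketched on its own. My reading of the paper is \(a_{i,t} = w_a^T \tanh(W_{va} v_i + W_{ha} h_t)\), then \(\alpha_t = \mathrm{softmax}(a_t)\) and \(\hat{v}_t = \sum_i \alpha_{i,t} v_i\). This minimal NumPy sketch assumes that form. The sizes, the random weights and the helper name top_down_attention are hypothetical, and the LSTMs themselves are not shown.&lt;/p&gt;

```python
import numpy as np

def top_down_attention(V, h1, W_va, W_ha, w_a):
    """Soft top-down attention over k bottom-up region features.

    V:  (k, D)  region features from the bottom-up (Faster R-CNN) model
    h1: (H,)    current output of the top-down attention LSTM
    Returns the attended feature v_hat (D,) and the weights alpha (k,).
    """
    a = np.tanh(V @ W_va.T + h1 @ W_ha.T) @ w_a   # (k,) one score per region
    alpha = np.exp(a - a.max())
    alpha /= alpha.sum()                          # softmax over regions
    return alpha @ V, alpha                       # weighted sum of region features

rng = np.random.default_rng(0)
k, D, H, A = 36, 16, 8, 10        # toy sizes (assumed)
V = rng.normal(size=(k, D))
v_bar = V.mean(axis=0)            # mean-pooled feature: an input to the attention LSTM
h1 = rng.normal(size=H)
W_va, W_ha, w_a = rng.normal(size=(A, D)), rng.normal(size=(A, H)), rng.normal(size=A)
v_hat, alpha = top_down_attention(V, h1, W_va, W_ha, w_a)
```

&lt;p&gt;The mean-pooled feature v_bar only feeds the first LSTM. The feature passed to the language LSTM is the attention-weighted v_hat, which changes at every time step.&lt;/p&gt;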

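&lt;p&gt;The final output y of the second LSTM layer is a sequence of words. The probability of the whole output sentence is the product of the conditional distributions at each time step. The paper gives loss functions for its different training settings, namely cross-entropy training and CIDEr optimization.&lt;/p&gt;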
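
&lt;p&gt;The product of conditionals becomes a sum of log-probabilities, and its negative is the cross-entropy loss for a reference caption. Below is a tiny worked example. The per-step distributions are made up, where in the model they would come from the softmax on top of the language LSTM. The helper name sequence_log_prob is hypothetical.&lt;/p&gt;

```python
import numpy as np

def sequence_log_prob(step_probs, tokens):
    """log p(y_1..T) = sum_t log p(y_t | y_1..t-1).

    step_probs: (T, vocab) conditional distribution predicted at each step t
    tokens:     (T,) indices of the words actually in the caption
    """
    return float(np.sum(np.log(step_probs[np.arange(len(tokens)), tokens])))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4]])   # toy distributions over a 3-word vocabulary
caption = np.array([0, 1, 2])
xe_loss = -sequence_log_prob(probs, caption)   # = -log(0.7 * 0.8 * 0.4)
```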

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/9jIy0Tm.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Probability of the output sentence &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/RWFKIfm.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Loss functions &lt;/p&gt;

&lt;p&gt;&lt;strong&gt;VQA Model&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The VQA model uses the question representation as its context and likewise uses soft attention. The overall architecture is shown below.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/pgOA7bM.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; VQA model architecture &lt;/p&gt;

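&lt;p&gt;As the figure shows, this is a joint multimodal embedding architecture that uses both the image and the question. The question representation guides the attention that produces the image feature. The question and image representations are then combined by an element-wise product to compute scores for the candidate answers.&lt;/p&gt;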
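
&lt;p&gt;The forward pass of this VQA head can be sketched as follows. It follows my reading of the paper: attention logits from the concatenation of each region feature with the question, gated tanh layers, an element-wise product for fusion and an independent sigmoid per candidate answer. The weight names in P, the sizes and the random values are hypothetical, and the GRU question encoder is not shown (q stands in for its output).&lt;/p&gt;

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_tanh(x, W, W_g):
    # gated tanh layer, tanh(W x) * sigmoid(W_g x); biases omitted for brevity
    return np.tanh(W @ x) * sigmoid(W_g @ x)

def vqa_answer_scores(V, q, P):
    """Question-guided attention, element-wise fusion and sigmoid answer scores.

    V: (k, D) bottom-up region features; q: (Q,) question encoding.
    P: dict of weights; the key names are hypothetical and the values random here.
    """
    # one attention logit per region, from the concatenation [v_i, q]
    a = np.array([P["w_a"] @ gated_tanh(np.concatenate([v, q]), P["W_a"], P["W_ag"])
                  for v in V])
    alpha = np.exp(a - a.max())
    alpha /= alpha.sum()                       # softmax over regions
    v_hat = alpha @ V                          # attended image feature
    # joint embedding: element-wise product of question and image representations
    h = gated_tanh(q, P["W_q"], P["W_qg"]) * gated_tanh(v_hat, P["W_v"], P["W_vg"])
    # independent sigmoid per candidate answer (multi-label), not a softmax
    return sigmoid(P["W_o"] @ gated_tanh(h, P["W_f"], P["W_fg"]))

rng = np.random.default_rng(0)
k, D, Q, Hd, N = 5, 6, 4, 8, 10               # toy sizes (assumed)
shapes = {"W_a": (Hd, D + Q), "W_ag": (Hd, D + Q), "w_a": (Hd,),
          "W_q": (Hd, Q), "W_qg": (Hd, Q), "W_v": (Hd, D), "W_vg": (Hd, D),
          "W_f": (Hd, Hd), "W_fg": (Hd, Hd), "W_o": (N, Hd)}
P = {name: rng.normal(size=s) for name, s in shapes.items()}
scores = vqa_answer_scores(rng.normal(size=(k, D)), rng.normal(size=Q), P)
```

&lt;p&gt;Concatenation appears only inside the attention scoring. The fusion of question and image uses the element-wise product, which requires both projections to have the same size (Hd here).&lt;/p&gt;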

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;Experimental results&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The experiments use the MSCOCO and VQA v2.0 datasets and, as expected, show better performance than previous methods.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/IOBPP8X.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;Captioning results &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/6dUcGag.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;VQA results &lt;/p&gt;

&lt;p&gt;Qualitative analysis also confirms that the model correctly identifies which parts of the image to attend to.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/33T4jcI.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;Attended regions for a generated caption&lt;/p&gt;

&lt;hr /&gt;

&lt;p&gt;&lt;strong&gt;Conclusion&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;This paper proposes combining a bottom-up model with the existing top-down model. This lets attention be applied to the task more naturally, and the approach performs well on both image captioning and VQA.&lt;/p&gt;

</description>
        <pubDate>Thu, 28 Feb 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/Bottom-Up-and-Top-Down-Attention-for-Image-Captioning-and-VQA.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/Bottom-Up-and-Top-Down-Attention-for-Image-Captioning-and-VQA.html</guid>
        
        <category>VQA</category>
        
        
        <category>VQA</category>
        
      </item>
    
      <item>
        <title>From RCNN to Mask R-CNN (2) Faster RCNN</title>
        <description>&lt;h1 id=&quot;faster-r-cnn&quot;&gt;Faster R-CNN&lt;/h1&gt;

&lt;p&gt;Hello. In this post I cover Faster R-CNN. Faster R-CNN addresses the limitations of its predecessor, Fast R-CNN, and drew a lot of attention because it is fast even at detection time. It owes this to moving region proposal, the step that made Fast R-CNN slow, inside the deep learning architecture. As detailed later, region proposal is computed inside the neural network on the GPU rather than on the CPU, which brings detection time down to about 0.198 seconds per image.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/9odBZIP.png&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt; Faster R-CNN architecture &lt;/p&gt;

&lt;h2 id=&quot;들어가며&quot;&gt;Introduction&lt;/h2&gt;

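&lt;p&gt;Faster R-CNN consists of two modules. The first is a deep convolutional network that proposes regions, and the second is the Fast R-CNN detector that uses those proposed regions. Much like an attention mechanism, the RPN module tells Fast R-CNN where to look.&lt;/p&gt;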

&lt;p&gt;&lt;strong&gt;RPN&lt;/strong&gt;&lt;/p&gt;

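&lt;p&gt;The RPN is the core component of Faster R-CNN. It takes an image of any size as input, working in practice on the feature map produced by the conv layers. It outputs a set of rectangular object proposals, each with an objectness score. The RPN is implemented as a fully convolutional network. Because the goal is to share computation with the Fast R-CNN detection network, both networks are assumed to share a common set of conv layers. The experiments use the ZF model, which has 5 shareable conv layers, and the VGG-16 model, which has 13.&lt;/p&gt;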

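&lt;p&gt;To generate region proposals, a small network is slid over the feature map output by the last shared conv layer. This small network takes an n × n spatial window of the feature map as input. Each sliding window is mapped to a lower-dimensional feature (256-d for ZF and 512-d for VGG, followed by ReLU).&lt;/p&gt;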

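&lt;p&gt;This feature is then fed into two sibling fully connected layers, a box-regression layer and a box-classification layer. The experiments use n = 3, for which the effective receptive field on the input image is 171 pixels for ZF and 228 pixels for VGG.&lt;/p&gt;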

&lt;p&gt;&lt;strong&gt;Anchors&lt;/strong&gt;&lt;/p&gt;

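&lt;p&gt;At each sliding-window location, multiple region proposals are predicted simultaneously. The maximum number of proposals per location is denoted k, a predefined parameter. The regression layer therefore has 4k outputs (box coordinates) and the classification layer has 2k outputs (object / not-object scores). These k reference boxes are called anchors, and each one is defined by a scale and an aspect ratio. The experiments use 3 scales and 3 aspect ratios, giving k = 9 anchors per location. Predefined anchors are very efficient: unlike an image pyramid they need no image resizing, and unlike multi-scale sliding windows they need no change of filter size. A conv feature map of size W × H has W × H × k anchors in total.&lt;/p&gt;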
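
&lt;p&gt;Anchor generation is easy to sketch. The idea is to build k boxes centred at the origin, one per (scale, ratio) pair, and then translate them to the centre of every feature-map cell. The defaults below (areas of 128², 256² and 512² pixels, ratios 1:2, 1:1 and 2:1, stride 16 for VGG-16) are the settings I recall from the paper. Treat them, the (x1, y1, x2, y2) box format and the helper name make_anchors as assumptions of this sketch.&lt;/p&gt;

```python
import numpy as np

def make_anchors(feat_h, feat_w, stride=16,
                 scales=(128, 256, 512), ratios=(0.5, 1.0, 2.0)):
    """Place k = len(scales) * len(ratios) anchors at every feature-map cell.

    Returns an array of shape (feat_h * feat_w * k, 4) with boxes as
    (x1, y1, x2, y2) in input-image pixels.
    """
    base = []
    for s in scales:
        for r in ratios:              # r = height / width; area stays s * s
            bw, bh = s / np.sqrt(r), s * np.sqrt(r)
            base.append([-bw / 2, -bh / 2, bw / 2, bh / 2])
    base = np.array(base)                                   # (k, 4), centred at 0
    ys, xs = np.meshgrid(np.arange(feat_h), np.arange(feat_w), indexing="ij")
    centres = (np.stack([xs, ys, xs, ys], axis=-1) + 0.5) * stride
    centres = centres.reshape(-1, 1, 4)                     # (H*W, 1, 4)
    return (centres + base[None, :, :]).reshape(-1, 4)      # (H*W*k, 4)

# a 40 x 60 feature map, roughly a 600 x 1000 input at stride 16
anchors = make_anchors(feat_h=40, feat_w=60)
print(anchors.shape)   # (21600, 4), i.e. 40 * 60 * 9
```

&lt;p&gt;The result matches the W × H × k count above. The anchors depend only on the feature-map size, so they can be computed once and reused for every image of that size.&lt;/p&gt;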

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/bcij0ZI.png&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt; RPN &lt;/p&gt;

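&lt;p&gt;Anchors are translation invariant: if an object in the image is translated, the proposal translates with it, and the same function can still predict it. The design also reduces the number of parameters, which lowers the computational cost.&lt;/p&gt;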

&lt;p&gt;&lt;strong&gt;RPN Loss Function&lt;/strong&gt;&lt;/p&gt;

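&lt;p&gt;RPN training has two parts: classifying whether each anchor is an object, and regressing its box coordinates. The RPN loss function is shown in the figure below. In it, \(N_{cls}\) is the mini-batch size (the number of anchors sampled per image) and \(N_{reg}\) is the number of anchor locations. An anchor is labeled positive if it has the highest IoU with some ground-truth box, or an IoU above 0.7 with any ground-truth box. It is labeled negative if its IoU is below 0.3 with every ground-truth box. These labels are what the RPN is trained on.&lt;/p&gt;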
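
&lt;p&gt;The labeling rule can be sketched with a pairwise IoU matrix. This is a minimal NumPy sketch assuming boxes in (x1, y1, x2, y2) format. Anchors that are neither positive nor negative get -1 and are left out of the loss. The helper names and the toy boxes are hypothetical, and mini-batch sampling of 256 anchors is not shown.&lt;/p&gt;

```python
import numpy as np

def box_area(b):
    return (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

def iou(boxes, gt):
    """Pairwise IoU between boxes (N, 4) and ground-truth boxes (M, 4)."""
    x1 = np.maximum(boxes[:, None, 0], gt[None, :, 0])
    y1 = np.maximum(boxes[:, None, 1], gt[None, :, 1])
    x2 = np.minimum(boxes[:, None, 2], gt[None, :, 2])
    y2 = np.minimum(boxes[:, None, 3], gt[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    return inter / (box_area(boxes)[:, None] + box_area(gt)[None, :] - inter)

def label_anchors(anchors, gt, pos_thr=0.7, neg_thr=0.3):
    """RPN anchor labels: 1 = positive, 0 = negative, -1 = ignored by the loss."""
    overlaps = iou(anchors, gt)              # (N, M)
    best = overlaps.max(axis=1)              # best IoU of each anchor over all GT boxes
    labels = np.full(len(anchors), -1)
    labels[neg_thr > best] = 0               # IoU below 0.3 with every GT box
    labels[best > pos_thr] = 1               # IoU above 0.7 with some GT box
    labels[overlaps.argmax(axis=0)] = 1      # the highest-IoU anchor of each GT box
    return labels

anchors = np.array([[0, 0, 10, 10],          # identical to the GT box: IoU 1.0
                    [0, 0, 9, 10],           # slightly smaller: IoU 0.9
                    [20, 20, 30, 30],        # no overlap: IoU 0.0
                    [0, 0, 10, 20]],         # twice as tall: IoU 0.5
                   dtype=float)
gt = np.array([[0, 0, 10, 10]], dtype=float)
labels = label_anchors(anchors, gt)          # labels: 1, 1, 0, -1
```

&lt;p&gt;The last rule, which marks the best anchor for each ground-truth box as positive, matters for objects whose shapes no anchor matches well. Without it, such a box might get no positive anchor at all.&lt;/p&gt;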

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/Y6ES5fA.png&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt; Loss function &lt;/p&gt;
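
&lt;p&gt;For reference, the RPN loss in the figure above can be written out as follows, transcribed as I understand the Faster R-CNN paper. \(p_i\) is the predicted objectness probability of anchor i, and \(p_i^*\) is its label (1 for positive, 0 for negative). \(t_i\) and \(t_i^*\) are the predicted and ground-truth box offsets. \(L_{cls}\) is the log loss, and \(L_{reg}\) is the smooth L1 loss; the factor \(p_i^*\) means the regression term only counts for positive anchors. The paper uses \(N_{cls} = 256\), \(N_{reg} \approx 2400\) and \(\lambda = 10\).&lt;/p&gt;

\[L(\{p_i\}, \{t_i\}) = \frac{1}{N_{cls}} \sum_i L_{cls}(p_i, p_i^*) + \lambda \frac{1}{N_{reg}} \sum_i p_i^* L_{reg}(t_i, t_i^*)\]

&lt;p&gt;The offsets are parameterized relative to the anchor box \((x_a, y_a, w_a, h_a)\), given as centre, width and height. The targets \(t^*\) use the same formulas with the ground-truth box in place of the predicted box.&lt;/p&gt;

\[t_x = \frac{x - x_a}{w_a}, \quad t_y = \frac{y - y_a}{h_a}, \quad t_w = \log\frac{w}{w_a}, \quad t_h = \log\frac{h}{h_a}\]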

&lt;h2 id=&quot;faster-r-cnn-학습과정&quot;&gt;Faster R-CNN training procedure&lt;/h2&gt;

&lt;ol&gt;
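  &lt;li&gt;The RPN is initialized from an ImageNet-pretrained model and fine-tuned end to end for the region proposal task.&lt;/li&gt;
  &lt;li&gt;A Fast R-CNN detector, also initialized from an ImageNet-pretrained model, is trained on the proposals produced by the RPN from step 1.&lt;/li&gt;
  &lt;li&gt;The detector from step 2 is used to initialize RPN training. The shared conv layers are fixed, and only the layers unique to the RPN are fine-tuned.&lt;/li&gt;
  &lt;li&gt;With the shared conv layers still fixed, the layers unique to Fast R-CNN are fine-tuned.&lt;/li&gt;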
&lt;/ol&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/xYCyHKY.png&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt; Faster R-CNN training procedure &lt;/p&gt;

&lt;h2 id=&quot;실험-및-결과&quot;&gt;Experiments and results&lt;/h2&gt;

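&lt;p&gt;The experiments use PASCAL VOC 2007, PASCAL VOC 2012 and MS COCO. 
ZF and VGG16 serve as the convolutional backbones. 
The experiments confirm that the RPN, sharing conv layers and multi-task training each improve performance. Above all, the RPN is much faster than Selective Search (SS). The results also show that adding more training data improves performance further.&lt;/p&gt;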

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/d4GaTrH.png&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt; Experimental results &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/m3dVQRh.png&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt; Speed comparison 1 &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/c2eNWkc.png.png&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt; Speed comparison 2 &lt;/p&gt;

&lt;p&gt;In summary, Faster R-CNN introduces a nearly cost-free region proposal method through the RPN and achieves fast and accurate object detection.&lt;/p&gt;
</description>
        <pubDate>Sun, 17 Feb 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/faster-rcnn.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/faster-rcnn.html</guid>
        
        <category>Object Detection</category>
        
        
        <category>Object</category>
        
        <category>Detection</category>
        
      </item>
    
      <item>
        <title>From RCNN to Mask R-CNN (1) R-CNN ~ Fast R-CNN</title>
        <description>&lt;h1 id=&quot;fast-r-cnn&quot;&gt;Fast R-CNN&lt;/h1&gt;

&lt;p&gt;Hello. In this post I summarize object detection. Object detection is a harder task than classification. Classification only determines the class of an image, whereas object detection must accurately classify several objects in an image and also locate each of them.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/nEfBXqw.jpg&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt; Object detection demands precise work. &lt;/p&gt;

&lt;p&gt;R-CNN stands out as the work that markedly raised object detection performance. SPPnet, Fast R-CNN and other methods followed from it. In this blog I aim to go beyond Faster R-CNN and eventually cover Mask R-CNN.&lt;/p&gt;

&lt;p&gt;Let's start with R-CNN.&lt;/p&gt;

&lt;hr /&gt;

&lt;h2 id=&quot;r-cnn&quot;&gt;&lt;strong&gt;R-CNN&lt;/strong&gt;&lt;/h2&gt;

&lt;p&gt;R-CNN drew attention because it put CNNs to full use and achieved high performance in object detection.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/mfvzydg.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/HV29CQH.png&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt; R-CNN training pipeline &lt;/p&gt;

&lt;p&gt;R-CNN works as follows.&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;First, generate region proposals for the image (about 2,000).&lt;/li&gt;
  &lt;li&gt;Warp or crop each region proposal to a fixed size and feed it to the CNN. The CNN is a network pretrained on ImageNet.&lt;/li&gt;
  &lt;li&gt;Using the features output by the CNN, classify each region with SVMs and refine its bounding box with a regressor.&lt;/li&gt;
&lt;/ol&gt;

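&lt;p&gt;R-CNN exploits the invariance properties of CNNs and achieved good object detection performance. However, it has several drawbacks. Training is a multi-stage process with three stages. The CNN must be run once for every region proposal from selective search. It requires a lot of storage. Above all, it is slow.&lt;/p&gt;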

&lt;p&gt;SPPnet was proposed to address these shortcomings of R-CNN.&lt;/p&gt;

&lt;hr /&gt;

&lt;h2 id=&quot;sppnet&quot;&gt;&lt;strong&gt;SPPnet&lt;/strong&gt;&lt;/h2&gt;

&lt;p&gt;SPPnet removes the biggest source of R-CNN's slowness, which is computing CNN features separately for each region proposal. As a result it trains about 3× faster and runs 10-100× faster at test time. The key is that it does not apply the CNN to each region proposal. Instead, it applies the CNN to the whole image once and uses the resulting feature map for every region proposal.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/ytLTxfD.png&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/fuIB1bY.png&quot; /&gt;&lt;/p&gt;
&lt;p align=&quot;center&quot;&gt; SPPnet training pipeline &lt;/p&gt;
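
&lt;p&gt;The idea of sharing one feature map across all proposals can be sketched as follows. This is a simplified illustration: the total stride of 16 (typical of the conv5 layer of networks such as VGG16), the random stand-in feature map, and the floor/ceil rounding in project_box are assumptions, and SPPnet's exact coordinate-projection rule differs in its details.&lt;/p&gt;

```python
import math

import numpy as np

STRIDE = 16  # assumed total downsampling factor of the conv backbone


def project_box(box, stride=STRIDE):
    """Map an image-space box (x1, y1, x2, y2) onto feature-map cells.
    Top-left is rounded down and bottom-right up, so the window covers the box."""
    x1, y1, x2, y2 = box
    return (x1 // stride, y1 // stride,
            math.ceil(x2 / stride), math.ceil(y2 / stride))


# Stand-in for the conv output of a 480 x 640 image: computed ONCE per image.
channels = 256
feature_map = np.random.default_rng(0).random((channels, 480 // STRIDE, 640 // STRIDE))

proposals = [(10, 20, 110, 220), (300, 50, 620, 470)]
for box in proposals:
    fx1, fy1, fx2, fy2 = project_box(box)
    # Each proposal just slices the shared map; no extra CNN forward pass.
    window = feature_map[:, fy1:fy2, fx1:fx2]
    print(box, "->", window.shape)
# Prints:
# (10, 20, 110, 220) -> (256, 13, 7)
# (300, 50, 620, 470) -> (256, 27, 21)
```

&lt;p&gt;The windows still differ in size, which is exactly the problem the SPP layer solves.&lt;/p&gt;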

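&lt;p&gt;SPPnet uses a structure called spatial pyramid pooling (SPP) so that it can accept images of arbitrary size. Put simply, the SPP layer divides the feature map into bins at several grid sizes and pools each bin, so the output has the same length regardless of the image size. In short, SPPnet greatly improves speed and no longer requires fixed-size input images.&lt;/p&gt;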
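
&lt;p&gt;A minimal NumPy sketch of an SPP layer makes the fixed-length property concrete. The pyramid levels (4, 2, 1) and the floor/ceil bin boundaries are illustrative assumptions, not necessarily the configuration used in the SPPnet paper.&lt;/p&gt;

```python
import math

import numpy as np


def spp(feature_map, levels=(4, 2, 1)):
    """Spatial pyramid pooling over a C x H x W feature map.
    For each level n the map is split into an n x n grid of bins and each bin
    is max-pooled; all results are concatenated into one vector of length
    C * sum(n * n for n in levels), whatever H and W are."""
    c, h, w = feature_map.shape
    pooled = []
    for n in levels:
        for i in range(n):
            for j in range(n):
                # Floor for the start and ceil for the end keeps every bin non-empty.
                r0, r1 = (i * h) // n, math.ceil((i + 1) * h / n)
                c0, c1 = (j * w) // n, math.ceil((j + 1) * w / n)
                pooled.append(feature_map[:, r0:r1, c0:c1].max(axis=(1, 2)))
    return np.concatenate(pooled)


# Feature maps of different spatial sizes all give the same output length.
rng = np.random.default_rng(0)
for h, w in [(13, 13), (10, 17), (30, 40)]:
    print((h, w), spp(rng.random((256, h, w))).shape)  # (5376,) each time
```

&lt;p&gt;Because the output length depends only on the number of channels and the pyramid levels (here 256 × (16 + 4 + 1) = 5376), the fully connected layers after the SPP layer can accept inputs of any size.&lt;/p&gt;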

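&lt;p&gt;SPPnet still has limitations. Because it keeps R-CNN's training pipeline, training is still multi-stage, features have to be stored on disk, and training remains slow. In addition, as the figure above shows, the convolutional layers before the SPP layer are not updated during fine-tuning, which makes task-specific fine-tuning difficult.&lt;/p&gt;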

&lt;p&gt;Fast R-CNN was proposed to keep the strengths of R-CNN and SPPnet while addressing their weaknesses.&lt;/p&gt;

&lt;hr /&gt;

&lt;h2 id=&quot;fast-r-cnn-1&quot;&gt;&lt;strong&gt;Fast R-CNN&lt;/strong&gt;&lt;/h2&gt;

&lt;p&gt;Now let's look at Fast R-CNN.&lt;/p&gt;

&lt;p&gt;Fast R-CNN was designed to overcome the drawbacks of R-CNN, and it is both faster and more accurate than R-CNN.
The paper lists the following three drawbacks of R-CNN and SPPnet.&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;Training is a multi-stage pipeline.&lt;/li&gt;
  &lt;li&gt;As a result, training takes a lot of time and GPU computation.&lt;/li&gt;
  &lt;li&gt;Object detection at test time is slow.&lt;/li&gt;
&lt;/ol&gt;

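&lt;p&gt;R-CNN is very slow because it runs a ConvNet on every object proposal separately, without sharing computation. SPPnet addresses this by computing a single conv feature map for the whole image and extracting each object proposal's features from that shared map, which makes computation much faster. However, SPPnet still has drawbacks such as its multi-stage pipeline, and the Fast R-CNN paper introduced here emphasizes that it fixes the weaknesses of both earlier methods.&lt;/p&gt;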

&lt;h2 id=&quot;fast-r-cnn-구조&quot;&gt;Fast R-CNN architecture&lt;/h2&gt;

&lt;p&gt;The overall architecture of Fast R-CNN is shown in Figure 1 below.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/G0hwkMF.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 1 &lt;/p&gt;

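&lt;p&gt;The network takes the entire image and a set of object proposals as input. The image passes through the convolutional layers to produce a conv feature map, and for each RoI (region of interest) a fixed-length feature vector is extracted from that feature map. Finally, after a sequence of fully connected (FC) layers, the network outputs for each RoI the softmax class probabilities and per-class bounding-box regression offsets. The whole network is trained end to end.&lt;/p&gt;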

&lt;p&gt;&lt;strong&gt;The RoI pooling layer&lt;/strong&gt;&lt;/p&gt;

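&lt;p&gt;The RoI pooling layer uses max pooling to map the features inside each valid RoI of the conv feature map to a small fixed spatial extent of H × W, where H and W are layer hyperparameters independent of any particular RoI. In the paper, an RoI is a rectangular window defined by a tuple (r, c, h, w): (r, c) is its top-left corner and (h, w) its height and width. The layer divides the h × w window into an H × W grid of sub-windows of roughly h/H × w/W each and max-pools each sub-window into the corresponding output cell. You can think of the RoI pooling layer as a special case of SPPnet's SPP layer with only a single pyramid level.&lt;/p&gt;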

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/E3uBBpa.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; RoI pooling layer example &lt;/p&gt;
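
&lt;p&gt;The RoI pooling operation described above can be sketched in NumPy as follows. This is an illustrative re-implementation under stated assumptions, not the authors' Caffe layer: the RoI is assumed to be already projected into feature-map coordinates, H = W = 7 is used as the default output size, and sub-window boundaries are rounded with floor/ceil.&lt;/p&gt;

```python
import math

import numpy as np


def roi_pool(feature_map, roi, H=7, W=7):
    """Max-pool one RoI (r, c, h, w) of a C x Hf x Wf feature map to C x H x W.
    (r, c) is the top-left corner and (h, w) the height and width, all in
    feature-map cells."""
    r, c, h, w = roi
    out = np.empty((feature_map.shape[0], H, W), dtype=feature_map.dtype)
    for i in range(H):
        for j in range(W):
            # Sub-window of roughly (h / H) x (w / W) cells.
            r0, r1 = r + (i * h) // H, r + math.ceil((i + 1) * h / H)
            c0, c1 = c + (j * w) // W, c + math.ceil((j + 1) * w / W)
            out[:, i, j] = feature_map[:, r0:r1, c0:c1].max(axis=(1, 2))
    return out


# RoIs of different sizes on one shared (hypothetical) feature map all become 512 x 7 x 7.
fmap = np.random.default_rng(0).random((512, 38, 50))
for roi in [(0, 0, 38, 50), (5, 10, 12, 9), (20, 30, 7, 7)]:
    print(roi, roi_pool(fmap, roi).shape)  # (512, 7, 7)

# Tiny worked example: a 4 x 4 map pooled to 2 x 2 keeps each quadrant's max.
tiny = np.arange(16, dtype=float).reshape(1, 4, 4)
print(roi_pool(tiny, (0, 0, 4, 4), H=2, W=2)[0].tolist())  # [[5.0, 7.0], [13.0, 15.0]]
```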

&lt;p&gt;&lt;strong&gt;Initializing from pre-trained networks&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The paper initializes Fast R-CNN from three ImageNet-pretrained networks. Each has five max pooling layers and between five and thirteen conv layers, and three changes are made when turning them into a Fast R-CNN network.&lt;/p&gt;

&lt;ol&gt;
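  &lt;li&gt;The last max pooling layer is replaced by an RoI pooling layer whose H and W are set to be compatible with the first fc layer (e.g., H = W = 7 for VGG16).&lt;/li&gt;
  &lt;li&gt;The network's last fc layer and softmax are replaced by the two sibling layers described above: a softmax over the classes and class-specific bounding-box regressors.&lt;/li&gt;
  &lt;li&gt;The network is modified to take two inputs: images and the region proposals (RoIs) in those images.&lt;/li&gt;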
&lt;/ol&gt;

&lt;p&gt;&lt;strong&gt;Fine-tuning for detection&lt;/strong&gt;&lt;/p&gt;

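&lt;p&gt;For detection fine-tuning, Fast R-CNN uses hierarchical sampling instead of the region-wise sampling of R-CNN and SPPnet: it first samples N images and then samples R/N RoIs from each, R RoIs in total. RoIs from the same image share computation and memory in the forward and backward passes, so with N = 2 and R = 128 as in the paper, training is roughly 64× faster than sampling one RoI from each of 128 different images. Because RoIs from the same image are correlated, one might expect slower convergence, but the paper reports that this is not a problem in practice. Above all, Fast R-CNN can fine-tune everything up to the final classifier and bounding-box regressors in a single stage.&lt;/p&gt;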
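
&lt;p&gt;A minimal sketch of hierarchical sampling, assuming a hypothetical dataset of 1,000 images with 2,000 proposals each. The function name and arguments are made up for illustration; the point is that only N images, and therefore only N conv forward passes, are needed for R RoIs.&lt;/p&gt;

```python
import random


def hierarchical_sample(num_images, num_proposals=2000, N=2, R=128, seed=0):
    """Sample a Fast R-CNN-style mini-batch: first N images, then R // N RoIs
    from each image. Returns a list of (image_index, proposal_index) pairs."""
    rng = random.Random(seed)
    images = rng.sample(range(num_images), N)
    batch = []
    for img in images:
        rois = rng.sample(range(num_proposals), R // N)
        batch.extend((img, roi) for roi in rois)
    return batch


batch = hierarchical_sample(num_images=1000)
images_used = {img for img, _ in batch}
print(len(batch), len(images_used))  # 128 2

# Conv forward passes per mini-batch: 128 when each RoI comes from a different
# image (the R-CNN/SPPnet strategy) versus 2 here, i.e. 128 / 2 = 64x fewer.
print(128 // len(images_used))  # 64
```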

&lt;p&gt;&lt;strong&gt;Multi-task loss&lt;/strong&gt;&lt;/p&gt;

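&lt;p&gt;Fast R-CNN has two sibling output layers. The classification layer outputs, for each RoI, a posterior probability distribution over the K + 1 classes (K object classes plus background), and the regression layer outputs bounding-box regression offsets. The ground truths for the two outputs are denoted u (the true class) and v (the true bounding-box regression target). The final loss function is shown below.&lt;/p&gt;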

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/VfRd3T0.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Loss function &lt;/p&gt;
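
&lt;p&gt;For reference, the loss in the figure is \(L(p, u, t^u, v) = L_{cls}(p, u) + \lambda [u \geq 1] L_{loc}(t^u, v)\), where \(L_{cls}(p, u) = -\log p_u\), the Iverson bracket \([u \geq 1]\) is 1 for object classes and 0 for background (u = 0), and \(L_{loc}\) sums the smooth L1 loss over the four box coordinates. The NumPy sketch below evaluates it for a single RoI; the array layout (one row of box offsets per class, with an unused row for background) is an assumption for illustration, and λ = 1 as in the paper.&lt;/p&gt;

```python
import numpy as np


def smooth_l1(x):
    """Smooth L1: 0.5 * x**2 where |x| is below 1, and |x| - 0.5 elsewhere."""
    abs_x = np.abs(x)
    return np.where(abs_x >= 1.0, abs_x - 0.5, 0.5 * x ** 2)


def fast_rcnn_loss(p, u, t, v, lam=1.0):
    """Multi-task loss for one RoI.
    p: (K + 1,) softmax probabilities, index 0 = background
    u: ground-truth class index
    t: (K + 1, 4) predicted box offsets per class (row 0 is unused)
    v: (4,) ground-truth regression target"""
    l_cls = -np.log(p[u])
    # [u >= 1]: background RoIs contribute no localization loss.
    l_loc = smooth_l1(t[u] - v).sum() if u >= 1 else 0.0
    return l_cls + lam * l_loc


p = np.array([0.2, 0.7, 0.1])   # K = 2 object classes + background
t = np.zeros((3, 4))
t[1] = [0.5, -2.0, 0.0, 0.0]    # offsets predicted for class 1
v = np.zeros(4)

print(fast_rcnn_loss(p, 1, t, v))  # -log(0.7) + (0.125 + 1.5) ≈ 1.982
print(fast_rcnn_loss(p, 0, t, v))  # background: -log(0.2) ≈ 1.609
```

&lt;p&gt;The smooth L1 term grows only linearly for large errors, which makes box regression less sensitive to outliers than a squared loss.&lt;/p&gt;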

&lt;p&gt;&lt;strong&gt;Mini-batch sampling&lt;/strong&gt;&lt;/p&gt;

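&lt;p&gt;As mentioned above, each mini-batch uses N = 2 and R = 128, i.e. 64 RoIs per image. 25% of the RoIs are taken from object proposals whose IoU with a ground-truth box is at least 0.5; the rest are taken from proposals whose maximum IoU lies in [0.1, 0.5) and are labeled as background. During training, each image is flipped horizontally with probability 0.5.&lt;/p&gt;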
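
&lt;p&gt;The foreground/background sampling rule can be sketched with NumPy. The helper names, the (x1, y1, x2, y2) box format, and the toy proposals below are assumptions for illustration; horizontal flipping is left out.&lt;/p&gt;

```python
import numpy as np


def iou(box, gts):
    """IoU between one box and each row of gts; boxes are (x1, y1, x2, y2)."""
    iw = np.clip(np.minimum(box[2], gts[:, 2]) - np.maximum(box[0], gts[:, 0]), 0, None)
    ih = np.clip(np.minimum(box[3], gts[:, 3]) - np.maximum(box[1], gts[:, 1]), 0, None)
    inter = iw * ih
    area = (box[2] - box[0]) * (box[3] - box[1])
    gt_areas = (gts[:, 2] - gts[:, 0]) * (gts[:, 3] - gts[:, 1])
    return inter / (area + gt_areas - inter)


def sample_rois(proposals, gts, rois_per_image=64, fg_fraction=0.25, seed=0):
    """Pick up to 25% foreground RoIs (max IoU of at least 0.5) and fill the
    rest with background RoIs (max IoU in [0.1, 0.5))."""
    rng = np.random.default_rng(seed)
    max_iou = np.array([iou(b, gts).max() for b in proposals])
    fg_idx = np.flatnonzero(max_iou >= 0.5)
    bg_idx = np.flatnonzero(np.logical_and(max_iou >= 0.1, 0.5 > max_iou))
    n_fg = min(int(fg_fraction * rois_per_image), len(fg_idx))
    n_bg = min(rois_per_image - n_fg, len(bg_idx))
    fg = rng.choice(fg_idx, n_fg, replace=False)
    bg = rng.choice(bg_idx, n_bg, replace=False)
    return fg, bg


# Toy data: one ground-truth box and 100 proposals sliding right by d pixels.
gts = np.array([[100.0, 100.0, 200.0, 200.0]])
proposals = np.array([[100.0 + d, 100.0, 200.0 + d, 200.0] for d in range(100)])
fg, bg = sample_rois(proposals, gts)
print(len(fg), len(bg))  # 16 48
```

&lt;p&gt;Proposals whose best IoU is below 0.1 (here, shifts of 82 pixels or more) are used neither as foreground nor as background.&lt;/p&gt;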

&lt;p&gt;&lt;strong&gt;Detection at test time&lt;/strong&gt;&lt;/p&gt;

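&lt;p&gt;At test time there are typically about 2,000 RoIs per image; when an image pyramid is used, each RoI is assigned to the scale at which its area is closest to 224 × 224 pixels. For each RoI the network outputs a class posterior distribution and bounding-box predictions, and each class is given a detection confidence equal to its estimated probability for that RoI.
Using these confidences, non-maximum suppression is then applied independently for each class to remove overlapping duplicate detections.&lt;/p&gt;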
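
&lt;p&gt;Per-class non-maximum suppression can be sketched as the greedy procedure below. This is a generic NMS implementation rather than the exact R-CNN/Fast R-CNN code; the IoU threshold of 0.3 and the toy boxes are assumptions for illustration.&lt;/p&gt;

```python
import numpy as np


def nms(boxes, scores, iou_thresh=0.3):
    """Greedy NMS for ONE class. boxes: (n, 4) as (x1, y1, x2, y2); scores: (n,).
    Repeatedly keeps the highest-scoring box and discards the remaining boxes
    whose IoU with it exceeds iou_thresh. Returns kept indices, best first."""
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0:
        i, rest = order[0], order[1:]
        keep.append(int(i))
        iw = np.clip(np.minimum(boxes[i, 2], boxes[rest, 2])
                     - np.maximum(boxes[i, 0], boxes[rest, 0]), 0, None)
        ih = np.clip(np.minimum(boxes[i, 3], boxes[rest, 3])
                     - np.maximum(boxes[i, 1], boxes[rest, 1]), 0, None)
        inter = iw * ih
        ious = inter / (areas[i] + areas[rest] - inter)
        order = rest[iou_thresh >= ious]   # keep only boxes that overlap little
    return keep


# Toy detections for one class: boxes 0 and 1 overlap heavily (IoU ≈ 0.68).
boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]], dtype=float)
scores = np.array([0.9, 0.8, 0.7])   # confidence of this class for each RoI
print(nms(boxes, scores))  # [0, 2]
```

&lt;p&gt;In a full detector this function would be called once per class, with that class's boxes and confidences.&lt;/p&gt;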

&lt;p&gt;&lt;strong&gt;Truncated SVD&lt;/strong&gt;&lt;/p&gt;

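&lt;p&gt;Truncated singular value decomposition (SVD) can make detection faster. Detection processes many RoIs per image, so a large share of the forward-pass time goes to the fully connected layers. The u × v weight matrix W of such a layer is approximately factorized using U (u × t, the first t left-singular vectors of W), a t × t diagonal matrix \(\Sigma_t\) holding the top t singular values, and V (v × t, the first t right-singular vectors):&lt;/p&gt;

&lt;p&gt;\(W \approx U \Sigma_t V^T\)
This reduces the number of parameters from uv to t(u + v).&lt;/p&gt;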
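
&lt;p&gt;The compression can be checked numerically with NumPy. In this sketch the layer sizes, the rank t = 64, and the nearly low-rank random W are assumptions; as in the paper, the single fc layer is replaced by two smaller layers, the first with weights \(\Sigma_t V^T\) and no bias, the second with weights U (the original bias would be added after it).&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
u, v, t = 512, 1024, 64   # hypothetical fc layer (y = W x) and truncation rank

# A weight matrix that is close to rank t, so truncation loses little.
W = rng.standard_normal((u, t)) @ rng.standard_normal((t, v))
W += 0.01 * rng.standard_normal((u, v))

U, S, Vt = np.linalg.svd(W, full_matrices=False)
first = S[:t, None] * Vt[:t]   # t x v layer: Sigma_t V^T
second = U[:, :t]              # u x t layer: U

x = rng.standard_normal(v)
y_full = W @ x
y_fast = second @ (first @ x)  # two cheaper matrix-vector products

print(W.size, first.size + second.size)  # 524288 98304, i.e. uv vs t(u + v)
print(np.linalg.norm(y_full - y_fast) / np.linalg.norm(y_full))  # small relative error
```

&lt;p&gt;How much accuracy survives depends on how quickly the singular values of the real weight matrix decay; a random full-rank matrix would compress much worse than this toy example.&lt;/p&gt;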

&lt;h2 id=&quot;실험결과&quot;&gt;Results&lt;/h2&gt;

&lt;p&gt;Fast R-CNN achieved state-of-the-art mAP on VOC 2007, 2010, and 2012. Above all, as the paper emphasizes from the start, it is dramatically faster than R-CNN and SPPnet.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/xXQsRw1.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Results: mAP &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/gYKd37p.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Results: speedup &lt;/p&gt;

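&lt;p&gt;However, detection in practice is still not fast enough, partly because generating region proposals remains a bottleneck. This will be covered in detail in the next post on Faster R-CNN.&lt;/p&gt;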
</description>
        <pubDate>Sat, 02 Feb 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/fast-rcnn.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/fast-rcnn.html</guid>
        
        <category>Object Detection</category>
        
        
        <category>Object</category>
        
        <category>Detection</category>
        
      </item>
    
      <item>
        <title>Semi-supervised learning with Ladder Network</title>
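        <description>&lt;p&gt;Hello! The paper reviewed today is Semi-supervised learning with Ladder Network, from the field of semi-supervised learning. On the Korea University DSBA lab channel on YouTube, you can watch my presentation on this paper. I hope this post is helpful to its readers. Let's begin the review. The figures in this post come from my presentation slides and from the paper itself.&lt;/p&gt;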

&lt;h2 id=&quot;들어가며&quot;&gt;Introduction&lt;/h2&gt;

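&lt;p&gt;This 2015 paper builds on From neural PCA to deep unsupervised learning, written by the same author in the same year. Both papers introduce a new architecture called the Ladder Network, but the earlier paper experimented only with unsupervised learning, whereas this one applies it together with supervised learning. Rather than using unsupervised learning only in a pre-training stage to help supervised learning, as was common, this paper keeps training the unsupervised part alongside the supervised part throughout training, so the unsupervised component can keep exploiting the many features of the data.&lt;/p&gt;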

&lt;p&gt;As I understand it, the reasoning that led to the Ladder Network goes as follows.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/nLOQIGm.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 1 &lt;/p&gt;

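&lt;p&gt;Latent variable models are generally preferred for semi-supervised learning. A latent variable model assumes that there are key latent variables that capture the characteristics of the given data. Since supervised and unsupervised learning differ only in whether labels are available, a latent variable model lets the two be combined effectively.&lt;/p&gt;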

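&lt;p&gt;A single-layer latent variable model, however, cannot effectively capture the hierarchical structure of data. For example, when recognizing a person's face, a hierarchical model can represent everything from the outline of the face and local details such as the eyes, nose, and mouth up to the core, invariant features that make up the face, whereas a single-layer model cannot. A hierarchical latent variable model with multiple layers is therefore the ideal choice.&lt;/p&gt;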

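&lt;p&gt;Hierarchical latent variable models, however, have many drawbacks when it comes to computation. This paper therefore uses the structure of the autoencoder, an unsupervised learning method, to exploit the hierarchical features of the data. A basic autoencoder has a single layer and so shares the weakness of single-layer latent variable models, so several layers are stacked.&lt;/p&gt;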

&lt;p&gt;In terms of structure, a stacked autoencoder resembles the hierarchical latent variable model mentioned above, but it has a major drawback: in an autoencoder, the connection from a lower layer to the layer above is deterministic rather than stochastic, because the layers are connected by mapping functions.&lt;/p&gt;

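&lt;p&gt;To fix this weakness of the autoencoder and exploit the hierarchical features of the data, this paper builds a network with lateral (horizontal) connections between the encoder and the decoder. Because the result looks like a ladder, it is called a Ladder Network. Figure 2 below gives a simplified view of the three structures discussed so far.&lt;/p&gt;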

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/GCKMagz.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 2 &lt;/p&gt;

&lt;p&gt;In short, the Ladder Network is an autoencoder network built to reflect the properties of a hierarchical latent variable model for efficient semi-supervised learning, and it makes heavy use of denoising to learn more effectively.&lt;/p&gt;

&lt;h2 id=&quot;denoising&quot;&gt;Denoising&lt;/h2&gt;

&lt;p&gt;Another key component of the paper is denoising. Denoising means training on data with added noise so that the model better captures the original, intrinsic features of the data. It can be discussed in two parts: the denoising source separation framework and the denoising autoencoder. If the probability model of the data x were known, it could be used for effective sampling and for an optimal denoising function. Conversely, a good (optimal) denoising function helps with sampling and with understanding the data x, and since learning a denoising function is easier than estimating the probability model, the Ladder Network relies on denoising.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Denoising Autoencoder&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;A denoising autoencoder (dAE) is an unsupervised method that trains a basic autoencoder on input data with added noise. Whereas an ordinary autoencoder learns to reproduce its input, a dAE encodes and reconstructs the noisy input and learns to make the output as close as possible to the original, noise-free input. A common analogy is recognizing a person in fog: even if someone is holding an umbrella in the fog, we can still tell that it is a person. The fog partly obstructs our view, but the key cues that identify a person still come through strongly. The dAE applies this principle to the autoencoder.
In the Ladder Network, the dAE part is trained much like supervised learning: training minimizes the loss between the output of the noisy corrupted path and the target.&lt;/p&gt;
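
&lt;p&gt;The dAE principle (corrupt the input, reconstruct it, compare the result with the clean input) can be illustrated with a deliberately tiny linear dAE in NumPy. The data, layer sizes, noise level, learning rate, and hand-written gradients are all assumptions for illustration; real dAEs use non-linear layers and an autodiff framework.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: 500 points in 8 dimensions that lie on a 4-dim subspace.
n, d, k = 500, 8, 4
X = rng.standard_normal((n, k)) @ rng.standard_normal((k, d))

W_enc = 0.1 * rng.standard_normal((d, k))   # encoder: x -> code
W_dec = 0.1 * rng.standard_normal((k, d))   # decoder: code -> reconstruction
noise_std, lr = 0.5, 0.02


def corrupt(x):
    """Corrupted input: the clean data plus Gaussian noise."""
    return x + noise_std * rng.standard_normal(x.shape)


def loss(x_noisy):
    """Mean squared error against the CLEAN data, not the noisy input."""
    return np.mean((x_noisy @ W_enc @ W_dec - X) ** 2)


X_eval = corrupt(X)      # fixed corrupted copy for a fair before/after comparison
before = loss(X_eval)
for step in range(3000):
    Xn = corrupt(X)                        # fresh noise every step
    H = Xn @ W_enc
    G = 2 * (H @ W_dec - X) / X.size       # d(loss) / d(reconstruction)
    grad_dec = H.T @ G
    grad_enc = Xn.T @ (G @ W_dec.T)
    W_enc -= lr * grad_enc
    W_dec -= lr * grad_dec
after = loss(X_eval)
print(before, after)   # the loss on the same noisy inputs goes down
```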

&lt;p&gt;&lt;strong&gt;Denoising Source Separation framework&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The Denoising Source Separation framework (DSS framework) refers to the way denoising is used in the author's related work, Denoising Source Separation. Whereas a dAE learns the relationship with the output, the DSS framework learns relationships between latent variables: in keeping with denoising, the latent variables computed from noisy data are trained to be as close as possible to the latent variables computed on the clean path. The DSS framework requires the latent variables z to be normalized, which the paper handles with batch normalization.&lt;/p&gt;
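
&lt;p&gt;A NumPy sketch of how a per-layer denoising cost can be computed on batch-normalized latent variables. The function names and toy data are hypothetical; following my reading of the Ladder Network paper, the decoder's reconstruction is normalized with the clean path's mini-batch mean and standard deviation so that both sides are compared on the same scale. Normalization also rules out the trivial solution of shrinking z toward a constant, which would make denoising trivially easy.&lt;/p&gt;

```python
import numpy as np

EPS = 1e-8


def batch_norm(z):
    """Normalize each unit (column) with its mini-batch mean and std."""
    mu, sigma = z.mean(axis=0), z.std(axis=0)
    return (z - mu) / (sigma + EPS), mu, sigma


def layer_denoising_cost(z_pre_clean, z_hat):
    """Denoising cost of one layer.
    z_pre_clean: (N, m) clean-path pre-activations of this layer
    z_hat:       (N, m) decoder reconstruction of this layer
    Both are put on the clean path's normalized scale, then the squared
    error is averaged over the batch size N and layer width m."""
    z_clean, mu, sigma = batch_norm(z_pre_clean)
    z_hat_bn = (z_hat - mu) / (sigma + EPS)
    n, m = z_clean.shape
    return np.sum((z_clean - z_hat_bn) ** 2) / (n * m)


rng = np.random.default_rng(0)
z_pre = 3.0 * rng.standard_normal((100, 50)) + 1.0       # toy clean activations
good = z_pre + 0.1 * rng.standard_normal(z_pre.shape)   # close reconstruction
bad = rng.standard_normal(z_pre.shape)                  # unrelated reconstruction
print(layer_denoising_cost(z_pre, good), layer_denoising_cost(z_pre, bad))
```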

&lt;h2 id=&quot;ladder-network&quot;&gt;Ladder Network&lt;/h2&gt;

&lt;p&gt;The figures below show the structure of the Ladder Network. As mentioned above, it consists of a corrupted path, a clean path, and a denoising path, and training combines supervised and unsupervised learning.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/qn1z3Zw.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 3 &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/OqucsZE.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 4 &lt;/p&gt;

&lt;h2 id=&quot;implementation-of-the-model&quot;&gt;Implementation of the model&lt;/h2&gt;

&lt;p&gt;Adding a Ladder Network to a model takes three steps.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;1: Build a feedforward model, trained with supervision, as the encoder&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;In the feedforward model, supervised and unsupervised learning proceed in parallel. As noted earlier, batch normalization is used for the DSS framework. According to the paper, batch normalization both reduces internal covariate shift, its usual benefit, and provides the normalization that the denoising framework requires. On the supervised side, the output produced by the corrupted path is trained to match the actual target; this is the supervised part shown in blue in Figure 4.&lt;/p&gt;
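
&lt;p&gt;One encoder layer of the clean and corrupted paths can be sketched as follows, loosely following the paper's encoder: batch normalization is applied to the pre-activation, Gaussian noise is added only on the corrupted path (to the input and after batch normalization), and a learned shift β and scale γ are applied before the non-linearity. The ReLU activation, noise level, and layer sizes are assumptions for illustration.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)


def encoder_layer(h_prev, W, beta, gamma, noise_std=0.0):
    """One Ladder encoder layer. noise_std=0 gives the clean path; a positive
    noise_std gives the corrupted path (noise is added after batch norm).
    Returns the normalized (possibly noisy) z and the activation h."""
    z_pre = h_prev @ W
    z = (z_pre - z_pre.mean(axis=0)) / (z_pre.std(axis=0) + 1e-8)  # batch norm
    z = z + noise_std * rng.standard_normal(z.shape)
    h = np.maximum(gamma * (z + beta), 0.0)                         # ReLU
    return z, h


# Hypothetical mini-batch of 32 inputs with 20 features, one layer of 10 units.
x = rng.standard_normal((32, 20))
W = rng.standard_normal((20, 10))
beta, gamma = np.zeros(10), np.ones(10)

# Both paths share the same parameters W, beta, gamma.
z_clean, h_clean = encoder_layer(x, W, beta, gamma)                  # clean path
x_tilde = x + 0.3 * rng.standard_normal(x.shape)                     # corrupted input
z_noisy, h_noisy = encoder_layer(x_tilde, W, beta, gamma, noise_std=0.3)

print(np.abs(z_clean.mean(axis=0)).max(), z_clean.std(axis=0).min())  # mean ~0, std ~1
```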

&lt;ul&gt;
  &lt;li&gt;2: Build a decoder that maps to each layer and supports unsupervised learning&lt;/li&gt;
&lt;/ul&gt;

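&lt;p&gt;The second step is the unsupervised part. Going down through the decoder, the reconstruction of each layer's latent variable is influenced both by the signal coming down from the layer above and by the laterally connected information from the same layer of the corrupted path, and it is trained to be close to the corresponding z of the clean path.&lt;/p&gt;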
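
&lt;p&gt;How the decoder merges the two signals can be sketched with a combinator function g, here in the form of what I understand to be the paper's vanilla combinator: with u the top-down signal from the decoder layer above and z̃ the lateral signal from the corrupted path, each unit computes \(\hat{z} = (\tilde{z} - \mu(u))\, v(u) + \mu(u)\), where μ and v are small parametric functions of u with per-unit weights a1 to a10. The parameter layout and the toy values are assumptions for illustration.&lt;/p&gt;

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def combinator(z_tilde, u, a):
    """Vanilla Ladder combinator, applied element-wise per unit.
    z_tilde: (N, m) lateral signal from the corrupted encoder
    u:       (N, m) top-down signal from the decoder layer above
    a:       (10, m) per-unit parameters a1..a10 (stored as a[0]..a[9])"""
    mu = a[0] * sigmoid(a[1] * u + a[2]) + a[3] * u + a[4]
    v = a[5] * sigmoid(a[6] * u + a[7]) + a[8] * u + a[9]
    # v(u) decides how much of the noisy lateral value to keep; mu(u) is the
    # estimate made from the top-down signal alone.
    return (z_tilde - mu) * v + mu


z_tilde = np.array([[1.0, 2.0, 3.0]])
u = np.array([[0.5, -1.0, 2.0]])

a = np.zeros((10, 3))
a[9] = 1.0                        # v(u) = 1, mu(u) = 0: pass the lateral signal through
print(combinator(z_tilde, u, a))  # [[1. 2. 3.]]

a = np.zeros((10, 3))
a[4] = 7.0                        # v(u) = 0, mu(u) = 7: ignore the lateral signal
print(combinator(z_tilde, u, a))  # [[7. 7. 7.]]
```

&lt;p&gt;The two extreme settings show the role of the lateral connection: the learned parameters let each unit interpolate between trusting the corrupted value and trusting the top-down prediction.&lt;/p&gt;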

&lt;ul&gt;
  &lt;li&gt;3: Train the Ladder Network by minimizing the sum of all loss functions&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Training minimizes the final loss function, which is the sum of the supervised loss function and the unsupervised loss function.&lt;/p&gt;
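
&lt;p&gt;The combined objective can be sketched as the supervised cross-entropy of the corrupted path (on labeled examples only) plus a weighted sum of per-layer denoising costs (on all examples), with one weight λ per layer. The function names, toy numbers, and λ values below are hypothetical. Setting every λ except the top layer's to zero corresponds to the Gamma (Γ) model discussed later.&lt;/p&gt;

```python
import numpy as np


def supervised_cost(probs_corrupted, targets):
    """Cross-entropy of the corrupted path's softmax output on labeled data."""
    rows = np.arange(len(targets))
    return -np.mean(np.log(probs_corrupted[rows, targets]))


def total_cost(probs_corrupted, targets, layer_costs, lambdas):
    """Ladder objective: supervised cost + sum over layers of lambda_l * C_d(l)."""
    return supervised_cost(probs_corrupted, targets) + float(np.dot(lambdas, layer_costs))


probs = np.array([[0.5, 0.5], [0.1, 0.9]])   # toy corrupted-path outputs, 2 labeled examples
targets = np.array([0, 1])
layer_costs = np.array([0.2, 0.4, 0.8])      # toy denoising costs for layers 0, 1, 2

print(total_cost(probs, targets, layer_costs, lambdas=[1.0, 0.5, 0.25]))  # ≈ 0.999
print(total_cost(probs, targets, layer_costs, lambdas=[0.0, 0.0, 1.0]))   # Γ-model: ≈ 1.199
```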

&lt;p&gt;&lt;strong&gt;Extension to CNNs&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;So far, the Ladder Network has been described with a multi-layer perceptron (MLP). However, it can easily be added to other existing neural networks such as CNNs: here too, one builds a decoder that mirrors the flow of the encoder in reverse.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Γ-model&lt;/strong&gt;&lt;/p&gt;

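&lt;p&gt;The model that uses only the highest layer of the decoder, i.e. applies the denoising cost only at the top layer, is called the Γ-model. It is a considerably simpler model, yet it also performs well in the experiments.&lt;/p&gt;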

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/pBxgvc6.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Γ-model &lt;/p&gt;

&lt;h2 id=&quot;실험-결과&quot;&gt;Results&lt;/h2&gt;

&lt;p&gt;The experiments use MNIST and CIFAR-10 with an MLP and with different CNN architectures; on CIFAR-10, only the Γ-model was tested. The number of labels used for the supervised part varies across experiments, while the unsupervised part always uses all of the data. Across the board, the results are state of the art compared with earlier work. On MNIST, the model performs well even with only 100 labels. Many neural network architectures already do very well on MNIST, but reaching this level with so few labels is noteworthy. On CIFAR-10, the Γ-model alone is enough to achieve good performance.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/8wBhqx9.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; MNIST + MLP &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/yT52obf.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; MNIST + CNN &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/FWrnwS7.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; CIFAR-10 + CNN &lt;/p&gt;

&lt;h2 id=&quot;마치며&quot;&gt;Closing thoughts&lt;/h2&gt;

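&lt;p&gt;Obtaining the data needed for training is still not easy. Just a few months ago, in the second half of last year, everyone in our lab devoted their own time to endless work building a high-quality dataset, and I still remember it vividly. Properly labeled data will likely remain hard to come by. In this respect, I believe deep learning can advance further if the rich information available to unsupervised learning can be applied to supervised tasks. The Ladder Network was an enjoyable paper to read, both for its goal and for its creative use of lateral connections to reach it. Studying it also made me keenly aware that I need a stronger mathematical foundation, and I hope to make up for that, and to develop creativity of my own, through further effort.&lt;/p&gt;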
</description>
        <pubDate>Sun, 13 Jan 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/Ladder-Network.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/Ladder-Network.html</guid>
        
        <category>Semi-supervised learning</category>
        
        
        <category>Semi-supervised</category>
        
        <category>learning</category>
        
      </item>
    
      <item>
        <title>PytorchZeroToAll(4)</title>
        <description>&lt;h1 id=&quot;pytorchzerotoall4&quot;&gt;PytorchZeroToAll(4)&lt;/h1&gt;

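&lt;p&gt;Hello. With this post I will wrap up Professor Sung Kim's PytorchZeroToAll lectures. This final post covers the RNN, a representative neural network architecture for processing sequential (time-series) data. The code in this post is taken from Sung Kim's lectures on YouTube.&lt;/p&gt;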

&lt;h2 id=&quot;rnn&quot;&gt;RNN&lt;/h2&gt;

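&lt;p&gt;RNN stands for Recurrent Neural Network. What most distinguishes an RNN from other neural networks is that the computation at each time step is influenced by earlier time steps through the hidden state. In PyTorch, the various RNN variants can be created as follows.&lt;/p&gt;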

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;cell&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;RNN&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_first&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;cell&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;GRU&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_first&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;cell&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;LSTM&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_first&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;p&gt;hidden_size is also the size of each output feature, and batch_first=True is set so that the input is laid out as (batch, seq_len, input_size).&lt;/p&gt;
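
&lt;p&gt;To see what nn.RNN computes internally, the NumPy sketch below implements the Elman recurrence given in the PyTorch documentation, \(h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1} W_{hh}^T + b_{hh})\), for a single sequence. The weights are random stand-ins whose shapes mirror nn.RNN's weight_ih_l0 (hidden_size × input_size) and weight_hh_l0 (hidden_size × hidden_size); this is an illustration, not PyTorch's implementation.&lt;/p&gt;

```python
import numpy as np


def rnn_forward(xs, h0, W_ih, W_hh, b_ih, b_hh):
    """Run an Elman RNN over one sequence.
    xs: (seq_len, input_size), h0: (hidden_size,).
    Returns the hidden state at every step (the RNN output) and the last one."""
    h = h0
    outputs = []
    for x_t in xs:
        # The new state depends on the current input AND the previous state.
        h = np.tanh(x_t @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
        outputs.append(h)
    return np.stack(outputs), h


rng = np.random.default_rng(0)
input_size, hidden_size = 4, 2
W_ih = rng.standard_normal((hidden_size, input_size))
W_hh = rng.standard_normal((hidden_size, hidden_size))
b_ih, b_hh = np.zeros(hidden_size), np.zeros(hidden_size)

# "hello" as one-hot rows, with the same h/e/l/o encoding as the example below.
one_hot = dict(zip("helo", np.eye(4)))
xs = np.stack([one_hot[c] for c in "hello"])
out, h_last = rnn_forward(xs, np.zeros(hidden_size), W_ih, W_hh, b_ih, b_hh)
print(out.shape, h_last.shape)  # (5, 2) (2,)
```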

&lt;p&gt;The code below builds a basic input and passes it through the RNN to produce the output.&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.autograd&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# One hot encoding for each char in 'hello'
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;e&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;l&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;o&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# One cell RNN input_dim (4) -&amp;gt; output_dim (2). sequence: 5
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cell&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;RNN&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_first&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# (num_layers * num_directions, batch, hidden_size) whether batch_first=True or False
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Propagate input through RNN
# Input: (batch, seq_len, input_size) when batch_first=True
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cell&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;sequence input size&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;out size&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h2 id=&quot;다음-character를-예측하는-기본적인-rnn-모델&quot;&gt;A basic RNN model that predicts the next character&lt;/h2&gt;

&lt;p&gt;Let's build a basic RNN model that takes hihell as input and produces ihello as output. As above, each character is one-hot encoded.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;sys&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.autograd&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;idx2char&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'h'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'i'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'e'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'l'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'o'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Teach hihell -&amp;gt; ihello
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;   &lt;span class=&quot;c1&quot;&gt;# hihell
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;one_hot_lookup&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 0
&lt;/span&gt;                  &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 1
&lt;/span&gt;                  &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 2
&lt;/span&gt;                  &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 3
&lt;/span&gt;                  &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 4
&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# ihello
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_one_hot&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;one_hot_lookup&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# As we have one batch of samples, we will change them to variables only once
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_one_hot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;LongTensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;p&gt;Next, set the hyperparameters.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;num_classes&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# one-hot size
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# output from the RNN. 5 to directly predict one-hot
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;   &lt;span class=&quot;c1&quot;&gt;# one sentence
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sequence_length&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# One by one
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# one-layer rnn
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

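&lt;p&gt;Next, build a model that uses an RNN. Pay attention to the output shape for a given input x: the input is reshaped to (batch, seq_len, input_size), and the output is flattened to (-1, num_classes) before it is returned.&lt;/p&gt;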

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;rnn&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;RNN&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                          &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_first&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Reshape input (batch first)
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sequence_length&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Propagate input through RNN
&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# Input: (batch, seq_len, input_size)
&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# hidden: (num_layers * num_directions, batch, hidden_size)
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;rnn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_classes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Initialize the hidden state (a plain RNN has no cell state)
&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# (num_layers * num_directions, batch, hidden_size)
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;


&lt;span class=&quot;c1&quot;&gt;# Instantiate RNN model
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
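&lt;p&gt;Because the shapes are the easiest part of this model to get wrong, here is a minimal standalone sketch (separate from the &lt;code&gt;Model&lt;/code&gt; class) that pushes one one-hot character through an &lt;code&gt;nn.RNN&lt;/code&gt; with the same sizes and prints the resulting shapes. The variable names &lt;code&gt;x&lt;/code&gt;, &lt;code&gt;h0&lt;/code&gt;, and &lt;code&gt;logits&lt;/code&gt; are illustrative only.&lt;/p&gt;

```python
import torch
import torch.nn as nn

# Same sizes as the hyperparameters above: one-hot vectors of length 5,
# hidden_size 5, one sentence, one character per step, one layer.
input_size, hidden_size, batch_size, sequence_length, num_layers = 5, 5, 1, 1, 1
num_classes = 5

rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)

x = torch.zeros(input_size)
x[2] = 1.0                                            # one one-hot character, shape (5,)
x = x.view(batch_size, sequence_length, input_size)   # (batch, seq_len, input_size)

# batch_first=True only affects the input/output tensors;
# the hidden state stays (num_layers * num_directions, batch, hidden_size).
h0 = torch.zeros(num_layers, batch_size, hidden_size)
out, h1 = rnn(x, h0)

logits = out.view(-1, num_classes)   # (batch * seq_len, num_classes) for CrossEntropyLoss
print(tuple(out.shape), tuple(h1.shape), tuple(logits.shape))
```

&lt;p&gt;It prints (1, 1, 5) (1, 1, 5) (1, 5). Because &lt;code&gt;hidden_size&lt;/code&gt; equals &lt;code&gt;num_classes&lt;/code&gt;, the flattened RNN output can be used directly as class scores.&lt;/p&gt;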

&lt;p&gt;Next, define the loss function and the optimizer used to train the model.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Set loss and optimizer function
# CrossEntropyLoss = LogSoftmax + NLLLoss
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;criterion&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CrossEntropyLoss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Adam&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
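&lt;p&gt;As the comment says, &lt;code&gt;nn.CrossEntropyLoss&lt;/code&gt; combines log-softmax and negative log-likelihood in one step, so the model should return raw scores (logits) rather than probabilities. A quick check of that identity on random logits (the values are arbitrary):&lt;/p&gt;

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(6, 5)                  # 6 characters, 5 classes (unnormalized scores)
targets = torch.tensor([1, 0, 2, 3, 3, 4])  # class indices, e.g. "ihello"

ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)

# The two losses agree up to floating-point error.
print(ce.item(), nll.item())
```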

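&lt;p&gt;Train the model over repeated epochs. In each epoch the hidden state is reset, the loss is accumulated over the characters one at a time, and then a single backward pass and optimizer step are taken.&lt;/p&gt;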

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Train the model
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zero_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init_hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;sys&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stdout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;write&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;predicted string: &quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;label&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# print(input.size(), label.size())
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;val&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;sys&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stdout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;write&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx2char&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;item&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()])&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;criterion&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;label&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# target shape (1,) to match output (1, num_classes)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;, epoch: %d, loss: %1.3f&quot;&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;item&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()))&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Learning finished!&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Instead of looping over the characters as above, the model can also be trained more simply by feeding each whole sequence at once.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Train the model
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
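    &lt;span class=&quot;c1&quot;&gt;# rnn: assumed to be a model whose forward(x) consumes the whole sequence at once
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# (sequence_length = 6) and returns logits of shape (6, num_classes)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;outputs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rnn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;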
    &lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zero_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;criterion&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;outputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;outputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;numpy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;result_str&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx2char&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;squeeze&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;epoch: %d, loss: %1.3f&quot;&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;item&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()))&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Predicted string: &quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;''&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;join&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;result_str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Learning finished!&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
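&lt;p&gt;Feeding the whole sequence in one call and looping over the characters while carrying the hidden state compute the same recurrence; the single call is just shorter and lets PyTorch run the time loop internally. A minimal check of that claim, using a standalone &lt;code&gt;nn.RNN&lt;/code&gt; with random weights (not the &lt;code&gt;rnn&lt;/code&gt; model above, whose definition is not shown):&lt;/p&gt;

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=5, hidden_size=5, batch_first=True)

seq = torch.eye(5)[[0, 1, 0, 2, 3, 3]]   # one-hot "hihell", shape (6, 5)
x = seq.view(1, 6, 5)                    # (batch=1, seq_len=6, input_size=5)

# 1) Whole sequence in a single call.
out_full, h_full = rnn(x, torch.zeros(1, 1, 5))

# 2) One character at a time, carrying the hidden state forward.
h = torch.zeros(1, 1, 5)
steps = []
for t in range(6):
    out_t, h = rnn(x[:, t:t + 1, :], h)  # slicing t:t+1 keeps the seq_len axis
    steps.append(out_t)
out_steps = torch.cat(steps, dim=1)      # (1, 6, 5)

print(torch.allclose(out_full, out_steps, atol=1e-6), torch.allclose(h_full, h, atol=1e-6))
```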

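&lt;p&gt;The lecture notes that training with embeddings instead of one-hot vectors gives better performance. A one-hot vector is fixed to 0s and 1s, whereas an embedding is a dense vector that is learned during training, which makes the input representation more flexible.&lt;/p&gt;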
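&lt;p&gt;Concretely, an embedding lookup is the same as multiplying a one-hot vector by a trainable weight matrix, except that the vector size is chosen freely and the values are learned rather than fixed. A small sketch of that equivalence (the sizes are arbitrary):&lt;/p&gt;

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_chars, embedding_dim = 5, 3
emb = nn.Embedding(num_chars, embedding_dim)

idx = torch.tensor([0, 1, 0, 2, 3, 3])       # "hihell" as character indices

looked_up = emb(idx)                         # (6, 3): selected rows of emb.weight
one_hot = F.one_hot(idx, num_classes=num_chars).float()
projected = one_hot @ emb.weight             # (6, 5) @ (5, 3) gives (6, 3)

print(torch.allclose(looked_up, projected))  # True
```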

&lt;h2 id=&quot;rnn-for-classification&quot;&gt;RNN for classification&lt;/h2&gt;

&lt;p&gt;RNNs can be used for a variety of tasks, such as many-to-one, one-to-many, and many-to-many. The lecture focuses in particular on how an RNN can be used for a many-to-one classification task.&lt;/p&gt;
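&lt;p&gt;In a many-to-one setup the RNN still processes every time step, but only the state after the last step is passed to the classifier. A minimal sketch with arbitrary sizes and random inputs, using a time-major &lt;code&gt;nn.GRU&lt;/code&gt; followed by a linear layer:&lt;/p&gt;

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, hidden_size, n_classes = 7, 4, 8, 3
gru = nn.GRU(hidden_size, hidden_size, num_layers=1)  # time-major: (S, B, H)
fc = nn.Linear(hidden_size, n_classes)

x = torch.randn(seq_len, batch, hidden_size)
output, hidden = gru(x)   # output: (S, B, H), hidden: (num_layers, B, H)

# Many-to-many would use every time step of `output`;
# many-to-one keeps only the final step.
last_step = output[-1]    # (B, H)
assert torch.allclose(last_step, hidden[-1])  # identical for a 1-layer, unidirectional GRU

logits = fc(hidden[-1])   # (B, n_classes): one prediction per sequence
print(tuple(logits.shape))  # (4, 3)
```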

&lt;p&gt;In the name-classification example, the input is not one-hot encoded; instead, each character is converted to its ASCII code and passed through an embedding layer, as in the code below.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;str2ascii_arr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;msg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;arr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;ord&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;msg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;arr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


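&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;RNNClassifier&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;RNNClassifier&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# input_size: number of distinct ASCII codes; each code becomes a hidden_size vector
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# LongTensor of ASCII codes in, one hidden_size vector per character out
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;embedded&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;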
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
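&lt;p&gt;To batch several names, the ASCII sequences must first have the same length. The sketch below is one way to do that under stated assumptions: &lt;code&gt;pad_batch&lt;/code&gt; is a hypothetical helper, right-padding with code 0 and an embedding vocabulary of 128 (the ASCII range 0-127) are assumptions, and the names are made-up examples. It produces the B x S &lt;code&gt;LongTensor&lt;/code&gt; that the classifier's &lt;code&gt;forward&lt;/code&gt; expects.&lt;/p&gt;

```python
import torch
import torch.nn as nn

def str2ascii_arr(msg):
    arr = [ord(c) for c in msg]
    return arr, len(arr)

def pad_batch(names):
    """Hypothetical helper: ASCII-encode names and right-pad with 0 into a B x S tensor."""
    seqs = [str2ascii_arr(name)[0] for name in names]
    max_len = max(len(s) for s in seqs)
    padded = torch.zeros(len(seqs), max_len, dtype=torch.long)
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = torch.tensor(s, dtype=torch.long)
    return padded

batch = pad_batch(["yang", "kim", "park"])
print(batch.shape)               # torch.Size([3, 4])

# Assumed vocabulary of 128 ASCII codes, so every code is a valid embedding index.
embedding = nn.Embedding(128, 100)
print(embedding(batch).shape)    # torch.Size([3, 4, 100])
```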
&lt;p&gt;The part of the model that implements the many-to-one structure is shown below.&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;RNNClassifier&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_layers&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;RNNClassifier&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_layers&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gru&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;GRU&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Note: we run this all at once (over the whole input sequence)
&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# input: B x S, so input.size(0) = B
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# input:  B x S  -- (transpose) --&amp;gt; S x B
&lt;/span&gt;        &lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Embedding S x B -&amp;gt; S x B x I (embedding size)
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;  input&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;embedded&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;  embedding&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;embedded&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Make a hidden
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_init_hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gru&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embedded&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;  gru hidden output&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Use the last layer output as FC's input
&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# No need to unpack, since we are going to use hidden
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;fc_output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;  fc output&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fc_output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fc_output&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_init_hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

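&lt;p&gt;If we process several names at once in a batch, the names have different lengths, so they must be padded to a common length. We use zero padding for this: every sequence is filled with zeros up to the length of the longest one.&lt;/p&gt;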

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# zero-pad every sequence up to the longest length in the batch
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;pad_sequences&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;vectorized_seqs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_lengths&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;seq_tensor&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;vectorized_seqs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_lengths&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())).&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;long&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;enumerate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;vectorized_seqs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_lengths&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;seq_tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;LongTensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_tensor&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Convert names to ASCII-code sequences and return them as a zero-padded LongTensor
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_variables&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;names&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sequence_and_length&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;str2ascii_arr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;name&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;names&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;vectorized_seqs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sl&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sequence_and_length&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;seq_lengths&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;LongTensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sl&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sequence_and_length&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;pad_sequences&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;vectorized_seqs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_lengths&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;p&gt;After that, you can train the model as usual.&lt;/p&gt;

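&lt;p&gt;When batching sequences of different lengths, you can additionally use pack_padded_sequence. It packs the zero-padded batch so that the RNN processes only the real (non-padded) time steps, and pad_packed_sequence converts the RNN output back into a padded tensor. This avoids wasted computation on padding and keeps the padding from affecting the final hidden state.&lt;/p&gt;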
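
&lt;p&gt;A minimal sketch of the pad, pack, run, unpack round trip, assuming a recent PyTorch (the enforce_sorted argument is not available in very old releases, which instead required the batch to be sorted by length in descending order). The three names, the embedding size of 8 and the hidden size of 16 are arbitrary illustrative choices, and the built-in pad_sequence stands in for the manual pad_sequences helper above.&lt;/p&gt;

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Three hypothetical names converted to ASCII codes, with different lengths.
seqs = [torch.tensor([ord(c) for c in name]) for name in ["adylov", "solan", "hard"]]
lengths = torch.tensor([len(s) for s in seqs])  # [6, 5, 4]

# Zero-pad to the longest length: shape (B, S_max) = (3, 6).
padded = pad_sequence(seqs, batch_first=True, padding_value=0)

embedding = nn.Embedding(128, 8)  # 128 ASCII codes, embedding size 8 (arbitrary)
gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

embedded = embedding(padded)  # (B, S_max, 8)
# Pack so the GRU skips padded positions; enforce_sorted=False lets PyTorch
# sort by length internally instead of requiring a pre-sorted batch.
packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
packed_out, hidden = gru(packed)

# Unpack the per-step outputs back into a zero-padded tensor (original batch order).
output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(output.shape)  # torch.Size([3, 6, 16])
print(hidden.shape)  # torch.Size([1, 3, 16])
```

&lt;p&gt;Because of the packing, hidden holds each sequence's state at its own last real character rather than at the padded end, which is what a name classifier that feeds hidden into the FC layer needs.&lt;/p&gt;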

&lt;p&gt;In PyTorch, tensors and models can be moved to the GPU with tensor.cuda() and classifier.cuda(). When several GPUs are available, nn.DataParallel() splits each batch across them for parallel processing. For more details, see the PyTorch tutorials.&lt;/p&gt;
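
&lt;p&gt;A small device-agnostic sketch of the same idea: .to(device) is equivalent to .cuda() when a GPU is present, and the code falls back to the CPU otherwise. The nn.Linear(4, 2) model and the random batch are placeholders for illustration only.&lt;/p&gt;

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

classifier = nn.Linear(4, 2)  # placeholder model; any nn.Module works the same way
if torch.cuda.device_count() > 1:
    # Split each input batch across all visible GPUs.
    classifier = nn.DataParallel(classifier)
classifier = classifier.to(device)

x = torch.randn(8, 4).to(device)  # inputs must live on the same device as the model
out = classifier(x)
print(out.shape, out.device)
```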

&lt;h2 id=&quot;마치며&quot;&gt;Closing Remarks&lt;/h2&gt;

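&lt;p&gt;So far I have summarized the PytorchZeroToAll course in my own words while following along. It is far from perfect, but I hope it serves as a useful resource for anyone starting out with PyTorch while studying deep learning. I would also like to express my deep gratitude to Sung Kim for sharing such great material for free and giving so many people the opportunity to learn. I will be back on the blog with more useful material.&lt;/p&gt;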
</description>
        <pubDate>Sun, 13 Jan 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/pytorch4.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/pytorch4.html</guid>
        
        <category>Pytorch</category>
        
        
        <category>Pytorch</category>
        
      </item>
    
      <item>
        <title>PytorchZeroToAll(3)</title>
        <description>&lt;h1 id=&quot;pytorchzerotoall3&quot;&gt;PytorchZeroToAll(3)&lt;/h1&gt;

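&lt;p&gt;Hello! In this post I continue with Professor Sung Kim's PytorchZeroToAll lecture series from last time. This post covers CNN, one of the most representative neural network architectures. The code used in this post is taken from Sung Kim's lectures on YouTube.&lt;/p&gt;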

&lt;h2 id=&quot;cnn&quot;&gt;CNN&lt;/h2&gt;
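&lt;p&gt;CNN (convolutional neural network) is one of the most widely used neural network architectures. Intuitively, a CNN looks at high-dimensional input one local region (patch) at a time, shrinking it to a lower-dimensional representation while still capturing its characteristic features. The convolution and subsampling (pooling) stages reduce the dimensionality and extract features (feature extraction); a fully connected layer then produces an output suited to the task (e.g. classification).&lt;/p&gt;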

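&lt;p&gt;A convolution layer slides a filter across the input, moving by the stride at each step, and extracts features from each patch. Zero padding, which adds zeros around the border before the filter is applied, is often used as well, for example to keep the output the same spatial size as the input.&lt;/p&gt;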
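
&lt;p&gt;How stride and padding determine the output size can be checked with the standard formula out = floor((W + 2P - K) / S) + 1, where W is the input width, K the kernel size, S the stride and P the padding (this assumes dilation 1, the Conv2d default). The sketch below compares the formula with nn.Conv2d on a 28x28 input; the conv_output_size helper is a hypothetical name introduced here for illustration.&lt;/p&gt;

```python
import torch
import torch.nn as nn

def conv_output_size(width, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution window (dilation 1)."""
    return (width + 2 * padding - kernel_size) // stride + 1

# Compare the formula with PyTorch on a single-channel 28x28 input.
x = torch.randn(1, 1, 28, 28)
for k, s, p in [(5, 1, 0), (5, 1, 2), (3, 2, 1)]:
    conv = nn.Conv2d(1, 1, kernel_size=k, stride=s, padding=p)
    print(k, s, p, conv(x).shape[-1], conv_output_size(28, k, s, p))
```

&lt;p&gt;With kernel 5 and no padding the map shrinks from 28 to 24, while padding 2 keeps it at 28, which is exactly the use of zero padding described above.&lt;/p&gt;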

&lt;p&gt;Pooling likewise slides a window by the stride and keeps one representative value per patch. Max pooling and average pooling are the most common choices.&lt;/p&gt;
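
&lt;p&gt;A tiny worked example of the two pooling types on a hand-made 4x4 input (the values are arbitrary). With a 2x2 window and stride 2, each non-overlapping patch is reduced to its maximum or its mean; nn.MaxPool2d(2) in the model below behaves the same way because its stride defaults to the kernel size.&lt;/p&gt;

```python
import torch
import torch.nn.functional as F

# A 4x4 single-channel "image" with shape (N=1, C=1, H=4, W=4).
x = torch.tensor([[1., 2., 5., 6.],
                  [3., 4., 7., 8.],
                  [0., 0., 1., 1.],
                  [0., 4., 1., 1.]]).view(1, 1, 4, 4)

# 2x2 window moving with stride 2: one output value per patch.
max_out = F.max_pool2d(x, kernel_size=2, stride=2)  # patch maxima
avg_out = F.avg_pool2d(x, kernel_size=2, stride=2)  # patch means
print(max_out.view(2, 2))  # [[4., 8.], [4., 1.]]
print(avg_out.view(2, 2))  # [[2.5, 6.5], [1., 1.]]
```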

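&lt;p&gt;As the names suggest, the key difference between a fully connected neural net and a locally connected neural net is that in the former every output neuron is computed from all input neurons, while in the latter each output neuron is computed only from a local region of the input. As a result, the two differ greatly in their number of parameters.&lt;/p&gt;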
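
&lt;p&gt;To put numbers on that difference, the sketch below counts parameters for a layer that maps a 1x28x28 input to 10 feature maps of 24x24 with a 5x5 receptive field (sizes borrowed from the first layer of the model below). PyTorch has no built-in locally connected layer, so that count is computed by hand, assuming one 5x5 weight patch plus one bias per output unit. A convolution is a locally connected layer whose weights are additionally shared across positions, which shrinks the count further.&lt;/p&gt;

```python
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

in_h = in_w = 28                  # single-channel 28x28 input
out_ch, k = 10, 5                 # 10 feature maps, 5x5 receptive field
out_h = out_w = in_h - k + 1      # 24 (stride 1, no padding)
n_out = out_ch * out_h * out_w    # 5760 output units

# Fully connected: every output unit is connected to every input pixel.
fc = nn.Linear(in_h * in_w, n_out)

# Locally connected without weight sharing: each output unit has its own
# k*k weights plus one bias (hand count; assumption stated above).
local_params = n_out * (k * k + 1)

# Convolution: local connections with weights shared across all positions.
conv = nn.Conv2d(1, out_ch, kernel_size=k)

print(count_params(fc), local_params, count_params(conv))  # 4521600 149760 260
```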

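&lt;p&gt;The code below is a CNN built from two convolution layers, each followed by max pooling and ReLU, and a final fully connected layer.&lt;/p&gt;

&lt;p&gt;The convolution layer is created as Conv2d(in_channels, out_channels, kernel_size). The 2d means the kernel slides over two spatial dimensions (height and width), so the layer expects input of shape (batch, channels, height, width).&lt;/p&gt;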

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;20&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mp&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MaxPool2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;320&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;in_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;in_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# flatten the tensor
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
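        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;log_softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;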


&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
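
&lt;p&gt;Where the 320 in nn.Linear(320, 10) comes from can be traced by pushing a dummy batch through the same layers, assuming 28x28 single-channel inputs such as MNIST (the input size this value implies). ReLU is left out of the trace because it does not change shapes; the batch size of 64 is arbitrary.&lt;/p&gt;

```python
import torch
import torch.nn as nn

# Same layer settings as in Net above.
conv1 = nn.Conv2d(1, 10, kernel_size=5)
conv2 = nn.Conv2d(10, 20, kernel_size=5)
mp = nn.MaxPool2d(2)

x = torch.randn(64, 1, 28, 28)  # hypothetical batch of 64 grayscale 28x28 images
shapes = [tuple(x.shape)]
for layer in (conv1, mp, conv2, mp):
    x = layer(x)
    shapes.append(tuple(x.shape))
flat = x.view(x.size(0), -1)  # 20 channels * 4 * 4 = 320 features per image

for s in shapes:
    print(s)
print("flattened:", tuple(flat.shape))
```

&lt;p&gt;Each 5x5 convolution removes 4 pixels per side pair (28 to 24, 12 to 8) and each pooling halves the size (24 to 12, 8 to 4), leaving 20 x 4 x 4 = 320 values for the fully connected layer.&lt;/p&gt;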

&lt;h2 id=&quot;advanced-cnn&quot;&gt;Advanced CNN&lt;/h2&gt;

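&lt;p&gt;Now let's move on to architectures that go beyond the basic CNN. The lecture covers the Inception model. Its most distinctive feature is the 1x1 convolution, which reduces the number of channels before the more expensive 3x3 and 5x5 convolutions and so keeps computation manageable. You can think of Inception as a model that, thanks to these 1x1 convolutions, can afford to apply several kernel sizes in parallel and concatenate the results for feature extraction.&lt;/p&gt;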
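
&lt;p&gt;A rough sketch of why the 1x1 convolution helps, counting parameters for a 5x5 branch with and without a 1x1 bottleneck. The 256 input channels are a hypothetical value chosen to make the saving visible; the 16 and 24 channel counts match the branch5x5 layers of the InceptionA class below.&lt;/p&gt;

```python
import torch
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

in_ch = 256  # hypothetical wide input

# A 5x5 convolution applied directly to all input channels.
direct = nn.Conv2d(in_ch, 24, kernel_size=5, padding=2)

# Inception-style branch: a 1x1 conv first squeezes 256 channels down to 16,
# then the 5x5 conv only has to work on 16 channels.
bottleneck = nn.Sequential(
    nn.Conv2d(in_ch, 16, kernel_size=1),
    nn.Conv2d(16, 24, kernel_size=5, padding=2),
)

x = torch.randn(1, in_ch, 12, 12)
print(direct(x).shape == bottleneck(x).shape)          # True: same output shape
print(count_params(direct), count_params(bottleneck))  # 153624 13736
```

&lt;p&gt;Both paths produce the same output shape, but the bottleneck version needs roughly a tenth of the weights (and a proportionally smaller number of multiplications per pixel).&lt;/p&gt;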

&lt;p&gt;To keep the code compact, we define this block once as a separate class named InceptionA, as shown below, and reuse it when building the network.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;InceptionA&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;in_channels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;InceptionA&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch1x1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;in_channels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch5x5_1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;in_channels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch5x5_2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;24&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch3x3dbl_1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;in_channels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch3x3dbl_2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;24&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch3x3dbl_3&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;24&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;24&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;in_channels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;24&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;branch1x1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch1x1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;branch5x5&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch5x5_1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;branch5x5&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch5x5_2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch5x5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;branch3x3dbl&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch3x3dbl_1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;branch3x3dbl&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch3x3dbl_2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch3x3dbl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;branch3x3dbl&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch3x3dbl_3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch3x3dbl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;branch_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;avg_pool2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;branch_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;outputs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;branch1x1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;branch5x5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;branch3x3dbl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;branch_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;outputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;88&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;20&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 88 = total channels concatenated by InceptionA
&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;incept1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;InceptionA&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;in_channels&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;incept2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;InceptionA&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;in_channels&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;20&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mp&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MaxPool2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1408&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 1408 = 88 channels x 4 x 4 for a 1x28x28 (MNIST) input
&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;in_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;incept1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;incept2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;in_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# flatten the tensor
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
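        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;log_softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;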


&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
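&lt;p&gt;Where does the 1408 in nn.Linear(1408, 10) come from? The sketch below traces the spatial size through the layers with the standard output-size formula. It assumes a 1x28x28 MNIST input and that InceptionA keeps the spatial size unchanged (its branches are padded, as the avg_pool2d call with padding=1 suggests) while producing 88 channels, the value conv2 expects. The helper conv_out is a hypothetical name used only for this illustration.&lt;/p&gt;

```python
import torch
import torch.nn as nn


def conv_out(size, kernel, stride=1, padding=0):
    # Hypothetical helper: output-size formula shared by Conv2d and MaxPool2d
    # along one spatial axis
    return (size + 2 * padding - kernel) // stride + 1


INCEPTION_OUT = 88  # channels produced by InceptionA (matches conv2's in_channels)

s = 28                        # assumed MNIST input: 1 x 28 x 28
s = conv_out(s, 5)            # conv1, kernel 5 -> 24
s = conv_out(s, 2, stride=2)  # max pool 2      -> 12
# incept1 (assumed) keeps 12 x 12 and outputs 88 channels
s = conv_out(s, 5)            # conv2, kernel 5 -> 8
s = conv_out(s, 2, stride=2)  # max pool 2      -> 4
# incept2 (assumed) keeps 4 x 4 and outputs 88 channels
flat = INCEPTION_OUT * s * s
print(flat)  # 1408

# Cross-check the conv/pool stages with real layers on dummy inputs
x = nn.MaxPool2d(2)(nn.Conv2d(1, 10, kernel_size=5)(torch.zeros(1, 1, 28, 28)))
print(x.shape)  # torch.Size([1, 10, 12, 12])
y = nn.MaxPool2d(2)(nn.Conv2d(INCEPTION_OUT, 20, kernel_size=5)(torch.zeros(1, INCEPTION_OUT, 12, 12)))
print(y.shape)  # torch.Size([1, 20, 4, 4])
```

&lt;p&gt;If the input size changes, the in_features of the final Linear layer has to change with it.&lt;/p&gt;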

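&lt;p&gt;Does stacking a CNN deeper always make it learn better? Unfortunately, depth alone does not guarantee better performance: problems such as vanishing gradients, overfitting, and degradation (training error rising as more layers are added) can keep training from going well. To address this, architectures such as ResNet, which adds skip connections so that gradients flow more easily during backpropagation, and DenseNet, which feeds each layer the outputs of the preceding layers, have been proposed. I will cover them in a future post.&lt;/p&gt;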
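&lt;p&gt;To give a rough idea of what a skip connection looks like, here is a minimal residual block sketch. It is a simplified illustration (no batch normalization, same number of input and output channels), not the exact block from the ResNet paper, and the class name ResidualBlock is hypothetical. The key line adds the input x back to the output of the convolutions, so the block only has to learn a residual and gradients get a direct path backward.&lt;/p&gt;

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Hypothetical, simplified residual block: out = relu(F(x) + x).

    Padding keeps the spatial size, so F(x) and x can be added element-wise.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        # Skip connection: add the input back before the final activation
        return F.relu(out + x)


block = ResidualBlock(channels=16)
x = torch.randn(2, 16, 8, 8)
print(block(x).shape)  # torch.Size([2, 16, 8, 8])
```

&lt;p&gt;If both convolutions output zero, the block reduces to relu(x), which is why stacking such blocks should not make the network worse than a shallower one.&lt;/p&gt;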
</description>
        <pubDate>Sat, 12 Jan 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/pytorch3.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/pytorch3.html</guid>
        
        <category>Pytorch</category>
        
        
        
      </item>
    
      <item>
        <title>PytorchZeroToAll(2)</title>
        <description>&lt;h1 id=&quot;pytorchzerotoall2&quot;&gt;PytorchZeroToAll(2)&lt;/h1&gt;

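&lt;p&gt;Hello! Continuing from the previous post, this post covers more of Professor Sung Kim's PytorchZeroToAll lectures. We start with logistic regression, then look at how to build neural networks with PyTorch and how to set up data for experiments. The code used in this post is taken from Sung Kim's lectures.&lt;/p&gt;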

&lt;h2 id=&quot;logistic-regression&quot;&gt;Logistic regression&lt;/h2&gt;

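&lt;p&gt;Compared with the linear model covered earlier, logistic regression differs in two ways: it passes the linear output through a sigmoid activation function, which maps it to a probability between 0 and 1, and it uses (binary) cross-entropy as the loss function.&lt;/p&gt;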
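&lt;p&gt;Both changes can be written out by hand as a quick sanity check. This sketch uses made-up logits z and labels y: it applies the sigmoid σ(z) = 1 / (1 + exp(-z)) and the averaged binary cross-entropy -mean(y log p + (1 - y) log(1 - p)), then compares the result with torch.nn.BCELoss, whose default reduction is the mean.&lt;/p&gt;

```python
import torch
import torch.nn as nn

# Made-up logits and 0/1 labels, for illustration only
z = torch.tensor([[-2.0], [-0.5], [0.5], [2.0]])
y = torch.tensor([[0.0], [0.0], [1.0], [1.0]])

# Sigmoid squashes any real number into (0, 1)
p = 1.0 / (1.0 + torch.exp(-z))

# Binary cross-entropy averaged over the batch
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()

builtin = nn.BCELoss()(torch.sigmoid(z), y)  # default reduction='mean'
print(manual.item(), builtin.item())
print(torch.allclose(manual, builtin))  # True
```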

&lt;p&gt;In the code, this shows up as changes to the y_pred computation (now wrapped in a sigmoid) and to the criterion (now BCELoss).&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.autograd&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn.functional&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;x_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;2.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;3.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;4.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]))&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
        In the constructor we instantiate an nn.Linear module
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linear&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# One in and one out
&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
        In the forward function we accept a Variable of input data and we must return
        a Variable of output data.
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sigmoid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;criterion&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BCELoss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduction&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'mean'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SGD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.01&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
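&lt;p&gt;The snippet above stops after creating the model, loss, and optimizer. A minimal training loop might look like the sketch below. It assumes PyTorch 0.4 or later, where plain tensors work directly and Variable is no longer needed, and it uses an nn.Sequential equivalent of the Model class (a Linear(1, 1) layer followed by a sigmoid); the 1000 epochs are an arbitrary choice.&lt;/p&gt;

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same toy data as above; assumes PyTorch 0.4+ (no Variable wrapper needed)
x_data = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y_data = torch.tensor([[0.0], [0.0], [1.0], [1.0]])

model = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())  # linear layer + sigmoid
criterion = nn.BCELoss()                               # mean reduction by default
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

losses = []
for epoch in range(1000):
    y_pred = model(x_data)             # forward pass
    loss = criterion(y_pred, y_data)   # binary cross-entropy
    optimizer.zero_grad()              # clear gradients from the previous step
    loss.backward()                    # backpropagate
    optimizer.step()                   # update weight and bias
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")

with torch.no_grad():
    # Predicted probabilities for two new inputs
    print(model(torch.tensor([[1.0], [7.0]])).squeeze())
```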

&lt;h2 id=&quot;wide--deep&quot;&gt;Wide &amp;amp; Deep&lt;/h2&gt;

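&lt;p&gt;Now let's make the logistic regression above wider (Wide) and deeper (Deep).
As the code shows, this is simply a matter of stacking several layers whose input and output dimensions line up. Compared with plain logistic regression, the input dimension grows from 1 to 8, and the model gets deeper with three layers: l1, l2, and l3.&lt;/p&gt;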

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;

&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
        In the constructor we instantiate three nn.Linear modules
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l3&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sigmoid&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Sigmoid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
        In the forward function we accept a Variable of input data and we must return
        a Variable of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Variables.
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sigmoid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sigmoid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;out1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sigmoid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;out2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# our model
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;


&lt;span class=&quot;c1&quot;&gt;# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the three
# nn.Linear modules which are members of the model.
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;criterion&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BCELoss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduction&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'mean'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SGD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
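&lt;p&gt;As a quick check of how the dimensions line up, the sketch below builds the same 8 → 6 → 4 → 1 network with nn.Sequential (an equivalent formulation, not the lecture's code) and runs a hypothetical batch of five 8-feature samples through it. Each Linear(in, out) layer holds in × out weights plus out biases.&lt;/p&gt;

```python
import torch
import torch.nn as nn

# Same 8 -> 6 -> 4 -> 1 stack as the Model class above, written with nn.Sequential
model = nn.Sequential(
    nn.Linear(8, 6), nn.Sigmoid(),
    nn.Linear(6, 4), nn.Sigmoid(),
    nn.Linear(4, 1), nn.Sigmoid(),
)

x = torch.randn(5, 8)   # hypothetical batch: 5 samples, 8 features each
y_pred = model(x)
print(y_pred.shape)     # torch.Size([5, 1])

# Parameter count: (8*6 + 6) + (6*4 + 4) + (4*1 + 1) = 54 + 28 + 5 = 87
n_params = sum(p.numel() for p in model.parameters())
print(n_params)         # 87
```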

&lt;h2 id=&quot;dataloader&quot;&gt;DataLoader&lt;/h2&gt;

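&lt;p&gt;When the dataset is large, processing all of it at once for every update is very inefficient. Instead, we split the data into several mini-batches and train on one batch at a time. PyTorch makes splitting data into batches easy.&lt;/p&gt;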
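&lt;p&gt;Concretely, splitting into mini-batches means shuffling the sample indices and slicing them into chunks of batch_size, with a smaller final batch when the dataset size is not a multiple of batch_size. The sketch below does this by hand on made-up data (100 samples, batch size 32); DataLoader takes care of this bookkeeping for you.&lt;/p&gt;

```python
import math
import torch

# Made-up data: 100 samples with 8 features each
N, batch_size = 100, 32
x = torch.randn(N, 8)

# Shuffle indices once per epoch, then walk through them in fixed-size chunks
perm = torch.randperm(N)
batches = [x[perm[i:i + batch_size]] for i in range(0, N, batch_size)]

print(len(batches), math.ceil(N / batch_size))  # 4 4
print([b.shape[0] for b in batches])            # [32, 32, 32, 4]
```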

&lt;p&gt;Loading data through a DataLoader starts with a custom Dataset class made up of three parts. Let's go through them with the code.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.autograd&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.utils.data&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DataLoader&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;DiabetesDataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot; Diabetes dataset.&quot;&quot;&quot;&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Initialize your data, download, etc.
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;xy&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loadtxt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'./data/diabetes.csv.gz'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                        &lt;span class=&quot;n&quot;&gt;delimiter&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;','&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;xy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_numpy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_numpy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__getitem__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;index&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;index&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;index&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__len__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;dataset&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DiabetesDataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;ul&gt;
  &lt;li&gt;1: download, read data, etc.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;This step fetches and reads the data. It also converts the NumPy arrays into torch tensors with torch.from_numpy.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;2: return one item on the index&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Returns the sample at the given index.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;3: return the data length&lt;/li&gt;
&lt;/ul&gt;

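&lt;p&gt;This method returns the length of the dataset, i.e. the number of samples. DataLoader uses it to know which indices it can draw.&lt;/p&gt;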
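&lt;p&gt;The three methods above are all that a map-style dataset needs. The sketch below mirrors DiabetesDataset in plain Python so that it runs without PyTorch or a CSV file. ToyDataset and its in-memory rows are hypothetical stand-ins: every column except the last is used as features and the last column as the label, as in xy[:, 0:-1] and xy[:, [-1]].&lt;/p&gt;

```python
class ToyDataset:
    """Hypothetical stand-in for DiabetesDataset, using plain lists instead of tensors."""

    def __init__(self, rows):
        # 1: load the data and split features (all but the last column) from the label
        self.len = len(rows)
        self.x_data = [row[:-1] for row in rows]
        self.y_data = [[row[-1]] for row in rows]  # 1-element list, like xy[:, [-1]]

    def __getitem__(self, index):
        # 2: return one (features, label) pair for the given index
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # 3: return the number of samples
        return self.len


rows = [[0.1, 0.5, 1.0], [0.3, 0.2, 0.0], [0.9, 0.4, 1.0]]
dataset = ToyDataset(rows)
print(len(dataset))  # 3
print(dataset[1])    # ([0.3, 0.2], [0.0])
```

&lt;p&gt;For a map-style dataset, DataLoader mainly relies on len(dataset) and dataset[i], so any class providing these two methods can be plugged in; in the real DiabetesDataset the lists are torch tensors instead.&lt;/p&gt;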

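&lt;p&gt;Once the Dataset is defined, the DataLoader class can fetch the data in batches of the chosen size. For well-known datasets such as MNIST and COCO, torchvision provides ready-made dataset classes, so they can be loaded without downloading them manually or writing a custom Dataset. In the code below, num_workers sets how many worker processes load data in parallel. As of this writing (2019-01-13) it did not seem to work on Windows, so setting it to 0 on Windows is recommended. (A common cause is that Windows starts worker processes with spawn, which requires the script's entry point to be guarded by if __name__ == '__main__':.)&lt;/p&gt;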

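&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.utils.data&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DataLoader&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;dataset&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DiabetesDataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;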
&lt;span class=&quot;n&quot;&gt;train_loader&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DataLoader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                          &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                          &lt;span class=&quot;n&quot;&gt;shuffle&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
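                          &lt;span class=&quot;n&quot;&gt;num_workers&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# parallel loading workers; set to 0 on Windows if loading fails&lt;/span&gt;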
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
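&lt;p&gt;Conceptually, a DataLoader with shuffle=True draws a fresh random order of indices every epoch, fetches each item through __getitem__, and groups the items into batches of batch_size (the last batch can be smaller unless drop_last=True). The generator below is a simplified single-process sketch of that idea in plain Python. simple_loader is a hypothetical helper, not PyTorch's implementation, and it skips collating items into tensors as well as the worker processes controlled by num_workers.&lt;/p&gt;

```python
import random


def simple_loader(dataset, batch_size, shuffle=True, seed=None):
    """Yield lists of dataset items, roughly what DataLoader does before collation."""
    indices = list(range(len(dataset)))
    if shuffle:
        # new random order on each call unless a fixed seed is given
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        yield [dataset[i] for i in batch_indices]


# Any object with __len__ and __getitem__ works; a list of (x, y) pairs is enough here.
data = [([float(i)], [i % 2]) for i in range(10)]
batches = list(simple_loader(data, batch_size=4, seed=0))
print([len(b) for b in batches])  # [4, 4, 2]
```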

&lt;h2 id=&quot;softmax-classfier&quot;&gt;Softmax Classifier&lt;/h2&gt;

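&lt;p&gt;Softmax is used because it turns the network's outputs into probabilities over the classes. It is not the only option, but it performs very well as a classifier. It is usually paired with the cross-entropy loss between the softmax output and the true target, and it is used for multi-class classification, where each sample belongs to exactly one class (multi-label tasks usually use a per-class sigmoid instead).&lt;/p&gt;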
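&lt;p&gt;As a small illustration of what softmax does to a classifier's raw outputs (logits), the plain-Python sketch below exponentiates each logit and divides by the sum, so the outputs are positive and add up to 1. Subtracting the largest logit first is a standard trick to avoid overflow; it does not change the result.&lt;/p&gt;

```python
import math


def softmax(logits):
    """Turn a list of logits into a probability distribution."""
    m = max(logits)                        # shift for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([0.1, 0.2, 0.9])
print([round(p, 3) for p in probs])  # the largest logit gets the largest probability
print(sum(probs))                    # 1.0, up to floating-point rounding
```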

&lt;p&gt;In PyTorch, the cross-entropy loss is easy to use, with two caveats: the target Y must contain class indices rather than one-hot vectors, and the prediction y_pred must be the raw logits rather than softmax outputs. This is because nn.CrossEntropyLoss() already applies softmax internally (it combines LogSoftmax and NLLLoss). All in all, it is a user-friendly design.&lt;/p&gt;
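&lt;p&gt;The two caveats can be checked by hand. In this plain-Python sketch, cross_entropy_from_logits computes the per-sample quantity that nn.CrossEntropyLoss is built on: the negative log-softmax of the logits at the target class index. The one-hot formulation gives the same number, which is why a class index is enough as the target. The function names are illustrative only.&lt;/p&gt;

```python
import math


def log_softmax(logits):
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]


def cross_entropy_from_logits(logits, target):
    """Loss for one sample: target is a class index, logits are raw scores."""
    return -log_softmax(logits)[target]


def cross_entropy_one_hot(logits, one_hot):
    """Same loss written with a one-hot target vector."""
    return -sum(t * lp for t, lp in zip(one_hot, log_softmax(logits)))


logits = [0.1, 0.2, 0.9]
print(cross_entropy_from_logits(logits, 2))
print(cross_entropy_one_hot(logits, [0, 0, 1]))  # same value
```

&lt;p&gt;Passing probabilities that were already softmaxed would apply softmax a second time inside the loss and give a different value, which is why raw logits are expected.&lt;/p&gt;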

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn.functional&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.optim&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torchvision&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;datasets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.autograd&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;


&lt;span class=&quot;c1&quot;&gt;# Softmax + CrossEntropy (logSoftmax + NLLLoss)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CrossEntropyLoss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

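&lt;p&gt;The loss also works on a whole batch of n samples at once: Y holds n class indices and Y_pred holds an n-by-C matrix of logits, where C is the number of classes.&lt;/p&gt;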

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;
&lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;LongTensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# input is of size nBatch x nClasses = 3 x 3
# Y_pred are logits (not softmax)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y_pred1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                                 &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                                 &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;2.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]))&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;Y_pred2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                                 &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                                 &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]))&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;l1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y_pred1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
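&lt;span class=&quot;n&quot;&gt;l2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y_pred2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;nb&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;item&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;item&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# l1 is smaller: each row of Y_pred1 has its largest logit at the target class&lt;/span&gt;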
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
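&lt;p&gt;With its default reduction='mean', nn.CrossEntropyLoss averages the per-sample losses over the batch. The plain-Python sketch below redoes that computation for the Y, Y_pred1 and Y_pred2 values above, so l1 and l2 can be checked by hand. cross_entropy_batch is a hypothetical helper, not a PyTorch function.&lt;/p&gt;

```python
import math


def cross_entropy_batch(logits_batch, targets):
    """Mean cross-entropy over a batch, like nn.CrossEntropyLoss(reduction='mean')."""
    losses = []
    for logits, target in zip(logits_batch, targets):
        m = max(logits)
        log_total = m + math.log(sum(math.exp(z - m) for z in logits))
        losses.append(log_total - logits[target])  # -log softmax(logits)[target]
    return sum(losses) / len(losses)


Y = [2, 0, 1]
Y_pred1 = [[0.1, 0.2, 0.9], [1.1, 0.1, 0.2], [0.2, 2.1, 0.1]]
Y_pred2 = [[0.8, 0.2, 0.3], [0.2, 0.3, 0.5], [0.2, 0.2, 0.5]]

l1 = cross_entropy_batch(Y_pred1, Y)
l2 = cross_entropy_batch(Y_pred2, Y)
print(round(l1, 3), round(l2, 3))  # 0.497 1.239
```

&lt;p&gt;Every row of Y_pred1 puts its largest logit on the target class while no row of Y_pred2 does, so l1 comes out much smaller than l2.&lt;/p&gt;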
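&lt;p&gt;The code below is an example of a neural network on the MNIST dataset with four hidden layers (five nn.Linear layers), using ReLU as the activation function. The training and test sets are loaded separately so that the final accuracy can be measured on held-out test data.&lt;/p&gt;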

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# https://github.com/pytorch/examples/blob/master/mnist/main.py
&lt;/span&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;__future__&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;print_function&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn.functional&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.optim&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torchvision&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;datasets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.autograd&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Training settings
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# MNIST Dataset
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_dataset&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;datasets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MNIST&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;root&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'./mnist_data/'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                               &lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                               &lt;span class=&quot;n&quot;&gt;transform&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ToTensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt;
                               &lt;span class=&quot;n&quot;&gt;download&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;test_dataset&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;datasets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MNIST&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;root&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'./mnist_data/'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                              &lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                              &lt;span class=&quot;n&quot;&gt;transform&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ToTensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Data Loader (Input Pipeline)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_loader&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;utils&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;DataLoader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                           &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                           &lt;span class=&quot;n&quot;&gt;shuffle&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;test_loader&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;utils&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;DataLoader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;test_dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                          &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                          &lt;span class=&quot;n&quot;&gt;shuffle&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;784&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;520&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;520&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;320&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l3&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;320&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;240&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l4&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;240&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;120&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l5&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;120&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;784&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Flatten the data (n, 1, 28, 28)-&amp;gt; (n, 784)
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Net&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;criterion&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CrossEntropyLoss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SGD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.01&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;momentum&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
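&lt;p&gt;Two quick sanity checks for this model, sketched in plain Python with hypothetical helper names. First, each nn.Linear(in_features, out_features) layer holds in_features * out_features weights plus out_features biases, so the size of Net can be counted by hand; in PyTorch, sum(p.numel() for p in model.parameters()) should return the same total. Second, test accuracy is typically computed by taking the index of the largest logit in each row as the prediction and comparing it with the label.&lt;/p&gt;

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected (nn.Linear-style) layer."""
    return in_features * out_features + out_features


# Layer widths of Net, from input to output: 784, 520, 320, 240, 120, 10
sizes = [784, 520, 320, 240, 120, 10]
total = sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))
print(total)  # 682090


def accuracy(logits_batch, labels):
    """Fraction of rows whose largest logit is at the true label."""
    correct = 0
    for logits, label in zip(logits_batch, labels):
        predicted = max(range(len(logits)), key=logits.__getitem__)  # argmax
        correct += int(predicted == label)
    return correct / len(labels)


print(accuracy([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]], [1, 2]))  # 0.5
```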
</description>
        <pubDate>Fri, 11 Jan 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/pytorch2.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/pytorch2.html</guid>
        
        <category>Pytorch</category>
        
        
        
      </item>
    
      <item>
        <title>learning by association - a versatile semi-supervised training method for neural networks</title>
        <description>&lt;p&gt;Hello! The paper reviewed today is learning by association - a versatile semi-supervised training method for neural networks. It was presented at CVPR 2017 and includes a comparison with the Ladder network, which will also be covered on this blog.&lt;/p&gt;

&lt;h2 id=&quot;들어가면서&quot;&gt;Introduction&lt;/h2&gt;

&lt;p&gt;This paper proposes a network architecture inspired by the way humans learn. Humans can learn from associations between pieces of data: a young child who has learned what a dog is from a few dog photos can tell whether a new photo shows a dog. Artificial neural networks, by contrast, need a label for every training sample to learn a well-performing model. This paper brings the human ability to learn from associations to neural networks (learning by association).&lt;/p&gt;

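&lt;p&gt;In brief, learning by association works as follows. For one batch, embeddings are computed for the labeled data and for the unlabeled data. An imaginary walker then moves from a sample in the labeled batch to a sample in the unlabeled batch. This transition follows a probability distribution derived from the similarity between the embeddings, and it is called an association.&lt;/p&gt;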

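&lt;p&gt;To judge whether an association is sound, the walker then makes a second transition, again according to embedding similarity, back to the labeled batch. If it returns to a sample of the same class it started from, the association is consistent. The goal of the network is therefore to learn embeddings that capture the essence shared by samples of the same class, which keeps different classes apart.&lt;/p&gt;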

&lt;h2 id=&quot;learning-by-association&quot;&gt;Learning by Association&lt;/h2&gt;

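&lt;p&gt;The paper assumes that good embeddings of samples from the same class are highly similar. Two batches of data are embedded into vectors by a CNN, and a walker that goes A-&amp;gt;B and then B-&amp;gt;A captures the relationship between them. The full process is shown in Figure 1 below.&lt;/p&gt;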

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/QcgW47n.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 1 &lt;/p&gt;

&lt;p&gt;The objective of learning by association is therefore to maximize the probability that a round trip A-&amp;gt;B-&amp;gt;A ends at a sample of the same class as the one it started from.&lt;/p&gt;

&lt;p&gt;First, the similarity between embeddings A and B is defined as below. The paper uses the dot product, which gave the best results, but other similarity measures are possible.&lt;/p&gt;

\[M_{ij} := A_i \cdot B_j\]

&lt;p&gt;Next, the similarities are converted into transition probabilities from A to B by normalizing each row of M with a softmax over j.&lt;/p&gt;

\[P^{ab}_{ij} = P(B_j \mid A_i) := (\mathrm{softmax}_{\mathrm{cols}}(M))_{ij} = \frac{\exp(M_{ij})}{\sum_{j'} \exp(M_{ij'})}\]
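&lt;p&gt;The two formulas above can be traced on toy numbers. This plain-Python sketch uses a hypothetical batch of two labeled embeddings A and three unlabeled embeddings B (hand-picked 2-D vectors), computes M with the dot product, and turns each row of M into transition probabilities \(P^{ab}\) with a softmax over j, so every row sums to 1.&lt;/p&gt;

```python
import math


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def similarity(A, B):
    """M[i][j] = A_i . B_j (dot product of embeddings)."""
    return [[dot(a, b) for b in B] for a in A]


def softmax_rows(M):
    """Row-wise softmax: row i becomes P(B_j | A_i) over all j."""
    out = []
    for row in M:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out


# Toy embeddings (hypothetical): A = labeled batch, B = unlabeled batch
A = [[1.0, 0.0], [0.0, 1.0]]
B = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]

M = similarity(A, B)
P_ab = softmax_rows(M)
print(M)                           # [[2.0, 0.0, 1.0], [0.0, 2.0, 1.0]]
print([sum(row) for row in P_ab])  # each row sums to 1, up to rounding
```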

&lt;p&gt;The transition probabilities in the opposite direction, \(P^{ba}\), are obtained the same way by replacing \(M\) with \(M^T\). Multiplying the two gives the round-trip probabilities.&lt;/p&gt;

\[P^{aba}_{ij} := (P^{ab}P^{ba})_{ij} = \sum_k P^{ab}_{ik} P^{ba}_{kj}\]

&lt;p&gt;Finally, the probability of a correct walk is given below, where \(i \sim j\) means that \(A_i\) and \(A_j\) belong to the same class.&lt;/p&gt;

\[P(\text{correct walk}) = \frac{1}{|A|} \sum_{i \sim j} P^{aba}_{ij}\]
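&lt;p&gt;Continuing on toy numbers, this sketch builds \(P^{ba}\) from the transposed similarity matrix, multiplies the two transition matrices to get the round-trip probabilities \(P^{aba}\), and sums the probability mass that lands on a same-class sample, divided by \(|A|\). The embeddings and class labels are hypothetical, and the list-based matrix helpers stand in for a tensor library.&lt;/p&gt;

```python
import math


def softmax_rows(M):
    out = []
    for row in M:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out


def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def transpose(X):
    return [list(col) for col in zip(*X)]


# Hypothetical toy batch: 3 labeled embeddings (classes 0, 0, 1) and 2 unlabeled ones
A = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
labels_A = [0, 0, 1]
B = [[2.0, 0.0], [0.0, 2.0]]

M = matmul(A, transpose(B))        # M_ij = A_i . B_j
P_ab = softmax_rows(M)             # A -> B
P_ba = softmax_rows(transpose(M))  # B -> A, using M^T
P_aba = matmul(P_ab, P_ba)         # round trip A -> B -> A

# P(correct walk): mass that returns to a same-class sample, divided by |A|
p_correct = sum(
    P_aba[i][j]
    for i in range(len(A))
    for j in range(len(A))
    if labels_A[i] == labels_A[j]
) / len(A)
print(p_correct)
```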

&lt;p&gt;Because the walker is added on top of a regular CNN, several losses are combined:&lt;/p&gt;

\[L_{total} = L_{walker} + L_{visit} + L_{classification}\]

&lt;p&gt;Each loss is described below.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Walker loss&lt;/li&gt;
&lt;/ul&gt;

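&lt;p&gt;When learning associations, what matters is that a walk stays within the same class, because only then is the relationship it finds meaningful. The walker loss therefore penalizes incorrect walks and encourages a uniform probability distribution over walks that end at samples of the same class. It is defined below, where H is the cross-entropy and T is the target distribution that is uniform over the samples sharing the class of the starting sample.&lt;/p&gt;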

\[L_{walker} := H(T,P^{aba})\]
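&lt;p&gt;A sketch of the walker loss in plain Python. The target follows the paper's definition: \(T_{ij}\) is 1 divided by the number of labeled samples in the class of \(A_i\) when \(A_i\) and \(A_j\) share a class, and 0 otherwise. The round-trip matrices below are made-up values, averaging over starting samples is an assumed reduction, and the small eps that guards against log(0) is an implementation choice, not part of the formula.&lt;/p&gt;

```python
import math


def walker_targets(labels):
    """T_ij = 1 / (number of samples with A_i's class) if same class, else 0."""
    counts = {}
    for c in labels:
        counts[c] = counts.get(c, 0) + 1
    return [[1.0 / counts[ci] if ci == cj else 0.0 for cj in labels] for ci in labels]


def walker_loss(P_aba, labels, eps=1e-8):
    """Cross-entropy H(T, P_aba), averaged over the rows (starting samples)."""
    T = walker_targets(labels)
    total = 0.0
    for t_row, p_row in zip(T, P_aba):
        total -= sum(t * math.log(p + eps) for t, p in zip(t_row, p_row))
    return total / len(labels)


labels = [0, 0, 1]
# Hypothetical round-trip probabilities (each row sums to 1)
good = [[0.45, 0.45, 0.10], [0.45, 0.45, 0.10], [0.05, 0.05, 0.90]]
bad = [[0.10, 0.10, 0.80], [0.10, 0.10, 0.80], [0.45, 0.45, 0.10]]
print(walker_loss(good, labels), walker_loss(bad, labels))  # good is lower
```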

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/Lx8WRSQ.png&quot; /&gt;&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Visit loss&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;To use the unlabeled samples effectively, the walker should visit all of them, not only the ones that are easy to associate; this leads to more general embeddings.&lt;/p&gt;

\[L_{visit} := H(V,P^{visit})\]

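&lt;p&gt;Here \(P^{visit}_j := \langle P^{ab}_{ij} \rangle_i\) is the probability of visiting \(B_j\), averaged over all starting samples \(A_i\), and \(V_j := 1/|B|\) is the uniform target distribution.&lt;/p&gt;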
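&lt;p&gt;A matching sketch for the visit loss: \(P^{visit}\) is the column mean of \(P^{ab}\), i.e. how often each unlabeled sample \(B_j\) is reached on average, and the loss is its cross-entropy against the uniform \(V_j = 1/|B|\). The transition matrices are hypothetical, and eps only guards against log(0).&lt;/p&gt;

```python
import math


def visit_loss(P_ab, eps=1e-8):
    """H(V, P_visit) with V uniform over the unlabeled samples B."""
    n_a, n_b = len(P_ab), len(P_ab[0])
    p_visit = [sum(P_ab[i][j] for i in range(n_a)) / n_a for j in range(n_b)]
    v = 1.0 / n_b
    return -sum(v * math.log(p + eps) for p in p_visit)


# Hypothetical transition matrices from 2 labeled to 3 unlabeled samples
balanced = [[0.5, 0.1, 0.4], [0.1, 0.5, 0.4]]  # visits B_0, B_1, B_2 fairly evenly
skewed = [[0.9, 0.1, 0.0], [0.9, 0.1, 0.0]]    # never visits B_2
print(visit_loss(balanced), visit_loss(skewed))  # balanced is lower
```

&lt;p&gt;A walker that never reaches one of the unlabeled samples gets a much larger visit loss, which is the pressure toward visiting all samples described above; a perfectly uniform visit distribution reaches the minimum value, log |B|.&lt;/p&gt;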

&lt;ul&gt;
  &lt;li&gt;Classification loss&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The classification loss is the standard loss of a network that classifies the embeddings into classes, computed on the labeled samples.&lt;/p&gt;

&lt;h2 id=&quot;experiments&quot;&gt;Experiments&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;MNIST Dataset&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Table 1 below shows the results on MNIST for a simple model without careful fine-tuning. Even this vanilla setup performs well.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/Iqkh2iD.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Table 1 &lt;/p&gt;

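&lt;p&gt;Figure 2 below shows how the associations evolve during training. Early in training the embeddings are poor and walks often return to a different class, but after training the round trips stay within the same class, which indicates that the model has learned well.&lt;/p&gt;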

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/erBYej9.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 2 &lt;/p&gt;

&lt;p&gt;Figure 3 shows which MNIST test samples were misclassified. As in the lower left, many errors occur on samples whose labels are ambiguous even to a human. In other words, apart from ambiguous cases that could fool a person, the model performs very well on MNIST.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/xRYlo4g.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 3 &lt;/p&gt;

&lt;p&gt;&lt;strong&gt;STL-10 Dataset&lt;/strong&gt;&lt;/p&gt;

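&lt;p&gt;Figure 4 shows results on the STL-10 dataset, listing the five highest-scoring images for each class. For classes that exist in the training data, such as cars and ships, the model identified the images correctly. Interestingly, even for inputs whose class does not appear in the training data, it returned images similar to (associated with) the input.&lt;/p&gt;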

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/AqF0ktP.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 4 &lt;/p&gt;

&lt;p&gt;&lt;strong&gt;SVHN&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Tables 2 and 3 below show results on SVHN. The method performs well even with few labeled samples. Training also improves as more labeled and more unlabeled data are used.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/XvAxkRe.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Table 2 &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/Qspo5zr.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Table 3 &lt;/p&gt;

&lt;p&gt;Table 4 shows the effect of the visit loss proposed in the paper. The weight of the visit loss should be chosen for each dataset. Table 4 also indicates that setting it too high carries a risk of overfitting.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/UyFaX8w.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Table 4 &lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Domain Adaptation&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Finally, the paper applies learning by association to domain adaptation, using SVHN as the source and MNIST as the target. Applying the domain adaptation method gives the best result, outperforming training on each dataset independently.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/ZWPYFos.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Table 5 &lt;/p&gt;

&lt;h2 id=&quot;마치며&quot;&gt;Closing Thoughts&lt;/h2&gt;

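&lt;p&gt;Deep learning spans many fields, and research on and application of deep learning have recently driven progress in all of them. Compared with the good datasets available for research, however, the data used in practice still seem to fall well short in both quality and quantity. In this situation, obtaining high-quality labeled data is even harder. Semi-supervised learning aims to combine unsupervised and supervised learning, and from this perspective I found it very interesting. This paper, which tackles the problem by exploiting associations between data, has been a valuable reference for my future research.&lt;/p&gt;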
</description>
        <pubDate>Thu, 10 Jan 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/learning-by-association-a-versatile-semi-supervised-training-method-for-neural-networks.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/learning-by-association-a-versatile-semi-supervised-training-method-for-neural-networks.html</guid>
        
        <category>Semi-supervised learning</category>
        
        
        <category>Semi-supervised</category>
        
        <category>learning</category>
        
      </item>
    
      <item>
        <title>PytorchZeroToAll(1)</title>
        <description>&lt;h1 id=&quot;pytorchzerotoall&quot;&gt;PytorchZeroToAll&lt;/h1&gt;

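&lt;p&gt;Hello! This post is about PyTorch, a deep learning framework that has been gaining popularity recently. With its 2.0 update, TensorFlow is moving to a Keras-based style, in an effort to move away from its static nature. PyTorch, unlike TensorFlow, has a dynamic structure. I haven't known much about PyTorch until now, but to celebrate the start of vacation, I'm going to start studying it!&lt;/p&gt;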

&lt;p&gt;I started learning TensorFlow with Sung Kim's YouTube lectures. When I revisited his channel after a while, I found that he had added a PyTorch course as well. This post is my own summary of Sung Kim's YouTube course PytorchZeroToAll. Because it is a personal summary, please keep in mind that it is not very systematic. The code in this post is taken from Sung Kim's code.&lt;/p&gt;

&lt;h2 id=&quot;to-begin&quot;&gt;To begin&lt;/h2&gt;

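&lt;p&gt;The most important part of deep learning is finding good weights through training. Gradient descent is the standard way to do this. It repeatedly moves the weights against the gradient of the loss, in the hope of converging to the global minimum of the loss.&lt;/p&gt;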
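&lt;p&gt;To make the update rule concrete, here is a plain-Python sketch (no PyTorch). It uses the bias-free model \(\hat{y} = xw\) with squared-error loss \(l = (xw - y)^2\), whose gradient is \(\partial l/\partial w = 2x(xw - y)\). It applies \(w \leftarrow w - \alpha\,\partial l/\partial w\) with the same data, initial weight (1.0) and learning rate (0.01) as the PyTorch code in this section. The helper names are illustrative.&lt;/p&gt;

```python
# Illustrative gradient descent without autograd; helper names are hypothetical.


def forward(x, w):
    # Linear model without bias: y_hat = x * w
    return x * w


def gradient(x, y, w):
    # d/dw of (x*w - y)^2 = 2 * x * (x*w - y)
    return 2 * x * (forward(x, w) - y)


x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0     # initial guess
lr = 0.01   # learning rate (alpha)

for epoch in range(10):
    for x_val, y_val in zip(x_data, y_data):
        w = w - lr * gradient(x_val, y_val, w)   # one update per sample
    print("epoch", epoch, "w =", round(w, 4))

print("predict (after training)", 4, forward(4, w))
```

&lt;p&gt;Each update multiplies the error w - 2 by \((1 - 2\alpha x^2)\), so w approaches 2 steadily. The PyTorch code performs the same computation, except that l.backward() derives the gradient automatically.&lt;/p&gt;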

&lt;p&gt;Below is how PyTorch performs backpropagation for gradient descent. The key pieces are l.backward(), which runs backpropagation, and w.grad, which stores the resulting gradient.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.autograd&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;x_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;2.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;3.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;2.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;4.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;6.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;w&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;requires_grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Any random value
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# our model forward pass
&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Loss function
&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Before training
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;predict (before training)&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Training loop
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_val&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_val&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_val&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_val&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Backward pass: computes dl/dw and stores it in w.grad
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\t&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;grad: &quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_val&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_val&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.01&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Manually zero the gradients after updating weights
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zero_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;progress:&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# After training
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;predict (after training)&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
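&lt;p&gt;For this model, l.backward() fills w.grad with \(\partial l/\partial w = 2x(xw - y)\). A common way to sanity-check such a gradient is to compare the analytic derivative with a central finite difference. The plain-Python sketch below does this with illustrative helper names. At w = 1, both give -2, -8 and -18 for the three samples. The value -2 matches the first grad printed by the loop above.&lt;/p&gt;

```python
# Illustrative gradient check; helper names are hypothetical.


def loss(x, y, w):
    # Squared error of the bias-free linear model y_hat = x * w
    return (x * w - y) ** 2


def analytic_grad(x, y, w):
    # What autograd computes for this loss: dl/dw = 2 * x * (x*w - y)
    return 2 * x * (x * w - y)


def numeric_grad(x, y, w, h=1e-5):
    # Central finite-difference approximation of dl/dw
    return (loss(x, y, w + h) - loss(x, y, w - h)) / (2 * h)


for x_val, y_val in zip([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]):
    print(x_val, y_val, analytic_grad(x_val, y_val, 1.0), numeric_grad(x_val, y_val, 1.0))
```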

&lt;h2 id=&quot;pytorch-rhythm&quot;&gt;Pytorch Rhythm&lt;/h2&gt;

&lt;p&gt;Building a model with PyTorch follows the three steps below.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;1: Design the model using a class and Variable&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Build a class that consists of an initializer and a forward method. The initializer creates the layers the model needs. Since this example is a simple linear model, it uses only a single linear layer. In torch.nn.Linear(1,1), the first 1 is the input size and the second 1 is the output size. The forward method uses the self.linear layer defined in the initializer to compute the prediction y_pred from the input x.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.autograd&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;x_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;2.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;3.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;2.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;4.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;6.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]))&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
        In the constructor we instantiate one nn.Linear module
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linear&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# One in and one out
&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
        In the forward function we accept a Variable of input data and we must return
        a Variable of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Variables.
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# our model
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
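&lt;p&gt;As a rough illustration of what torch.nn.Linear(in_features, out_features) computes, the NumPy sketch below mimics it:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;the layer stores a weight of shape (out_features, in_features) and a bias of shape (out_features,);&lt;/li&gt;
  &lt;li&gt;it maps a batch x of shape (N, in_features) to x @ W.T + b, of shape (N, out_features).&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;The class name LinearSketch is hypothetical. The uniform initialization range of plus or minus 1/sqrt(in_features) is meant to resemble PyTorch's default; treat it as an assumption.&lt;/p&gt;

```python
# Illustrative NumPy stand-in for a fully connected layer; not PyTorch's implementation.
import numpy as np


class LinearSketch:
    """Computes y = x @ W.T + b, with W of shape (out_features, in_features)."""

    def __init__(self, in_features, out_features, seed=0):
        rng = np.random.default_rng(seed)
        bound = 1.0 / np.sqrt(in_features)   # assumed init range, see lead-in
        self.weight = rng.uniform(-bound, bound, size=(out_features, in_features))
        self.bias = rng.uniform(-bound, bound, size=(out_features,))

    def __call__(self, x):
        # x: (N, in_features) -> (N, out_features)
        return x @ self.weight.T + self.bias


layer = LinearSketch(1, 1)            # one input feature, one output feature
x = np.array([[1.0], [2.0], [3.0]])   # 3 samples x 1 feature, like x_data above
print(layer.weight.shape, layer.bias.shape, layer(x).shape)   # (1, 1) (1,) (3, 1)
```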

&lt;ul&gt;
  &lt;li&gt;2: Build the loss and the optimizer&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;This step is represented by criterion and optimizer: choose the appropriate loss function and optimizer from the torch API. This example uses the MSE loss and the SGD optimizer. The model's parameters, model.parameters(), are passed to the optimizer as the values to optimize.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the
# nn.Linear module which is a member of the model.
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;criterion&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MSELoss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduction&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'sum'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# sum of squared errors (older releases: size_average=False)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SGD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.01&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
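&lt;p&gt;The loss reduction and the optimizer step work as follows:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Summed loss: size_average=False, spelled reduction='sum' in newer PyTorch releases where size_average is deprecated, adds up the squared errors over the batch.&lt;/li&gt;
  &lt;li&gt;Mean loss: reduction='mean' divides that sum by the number of elements. With a summed loss, the gradients, and therefore the effective step for a fixed lr, grow with the batch size.&lt;/li&gt;
  &lt;li&gt;SGD step: with no momentum or weight decay configured, plain SGD updates each parameter as p - lr * grad.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Below is a plain-Python sketch of both reductions and of that update. The function names are illustrative.&lt;/p&gt;

```python
# Illustrative sketch of MSE reductions and a vanilla SGD step; names are hypothetical.


def mse(preds, targets, reduction="mean"):
    # Squared error per element, then reduce by 'sum' or 'mean'
    sq = [(p - t) ** 2 for p, t in zip(preds, targets)]
    if reduction == "sum":
        return sum(sq)
    if reduction == "mean":
        return sum(sq) / len(sq)
    raise ValueError("reduction must be 'sum' or 'mean'")


def sgd_step(params, grads, lr=0.01):
    # Vanilla SGD: p <- p - lr * grad (no momentum, no weight decay)
    return [p - lr * g for p, g in zip(params, grads)]


preds, targets = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
print(mse(preds, targets, "sum"), mse(preds, targets, "mean"))   # 14.0 and about 4.667

# Gradients of the summed loss w.r.t. (w, b) for y_hat = w*x + b at w=1, b=0 on this data
print(sgd_step([1.0, 0.0], [-28.0, -12.0]))
```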

&lt;ul&gt;
  &lt;li&gt;3: Define the training cycle and train&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;This is the training step. Each epoch runs a forward pass and a backward pass, and the parameters are optimized for the given number of epochs.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Training loop
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;500&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Forward pass: Compute predicted y by passing x to the model
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Compute and print loss
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;criterion&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Zero gradients, perform a backward pass, and update the weights.
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zero_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
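&lt;p&gt;The three calls in the loop (optimizer.zero_grad(), loss.backward() and optimizer.step()) can be unrolled by hand for this one-input linear model. The plain-Python sketch below assumes a summed squared error and vanilla SGD with lr=0.01, as in the code above. It starts from w = 0, b = 0 instead of PyTorch's random initialization. The gradients are written out analytically rather than computed by autograd.&lt;/p&gt;

```python
# Illustrative hand-unrolled training loop; the starting point (w=0, b=0) is an assumption.
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w, b = 0.0, 0.0
lr = 0.01

for epoch in range(500):
    # Forward pass and summed squared error
    preds = [w * x + b for x in x_data]
    loss = sum((p - y) ** 2 for p, y in zip(preds, y_data))

    # "zero_grad + backward": gradients recomputed from scratch each step
    grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, y_data, x_data))
    grad_b = sum(2 * (p - y) for p, y in zip(preds, y_data))

    # "step": vanilla SGD update
    w -= lr * grad_w
    b -= lr * grad_b

print("w =", round(w, 3), "b =", round(b, 3), "loss =", round(loss, 6))
print("predict (after training)", 4, w * 4 + b)
```

&lt;p&gt;After 500 epochs, w ends close to 2 and b close to 0, so the prediction for x = 4 comes out close to 8. The PyTorch version's numbers will differ slightly because its starting weights are random.&lt;/p&gt;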

&lt;p&gt;&lt;strong&gt;Testing Model&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;After training, the model must be tested to check its performance.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# After training
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hour_var&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;4.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hour_var&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;predict (after training)&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_pred&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;][&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;These are the basic steps. Even for a more complex architecture, PyTorch can be used the same way, provided the complex model is built well in the first step.&lt;/p&gt;

</description>
        <pubDate>Thu, 10 Jan 2019 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/pytorch1.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/pytorch1.html</guid>
        
        <category>Pytorch</category>
        
        
        <category>Pytorch</category>
        
      </item>
    
      <item>
        <title>Nonlinear Dimensionality Reduction (Nonlinear Mapping)</title>
        <description>&lt;h1 id=&quot;unsupervised-methods-nonlinear-mapping&quot;&gt;Unsupervised Methods: Nonlinear Mapping&lt;/h1&gt;

&lt;p&gt;Hello, I'm Woosik Yang, a master's student in the DSBA Lab at Korea University.&lt;/p&gt;

&lt;p&gt;In this post, I introduce three nonlinear mapping methods for unsupervised learning: Isomap, LLE (Locally Linear Embedding) and t-SNE.&lt;/p&gt;

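&lt;p&gt;This post draws on Professor Pilsung Kang's Business-Analytics course at Korea University, Professor Hayashi's lecture materials from Nagoya University, and various other lectures. The code in this post is based on code by Seungwan Seo, a Ph.D. student in the DSBA Lab at Korea University, and by Liam Schoneveld.&lt;/p&gt;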

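&lt;p&gt;Before we begin, let me briefly explain unsupervised learning and dimensionality reduction.
Unsupervised learning is a type of machine learning. Compared with supervised learning and reinforcement learning, its defining feature is that inputs come with no corresponding output, that is, no correct answer. It is typically used for dimensionality reduction and clustering. Dimensionality reduction is further divided into linear and nonlinear dimensionality reduction.&lt;/p&gt;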

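&lt;p&gt;Why is dimensionality reduction needed in the first place?
Suppose the data are high-dimensional, with a very large number of variables. Unimportant variables may then prevent good performance, and the sheer amount of data to process becomes very large. If we can reduce such data to a lower dimension while extracting its representative features, we get better performance and more efficient processing. A representative dimensionality reduction algorithm is principal component analysis (PCA), which embeds data in a lower dimension while preserving as much variance as possible.
However, if the data lie on a manifold as in Figure 1, a linear method such as PCA cannot capture the characteristics that distinguish the data labels shown by each color.
A manifold is a shape embedded in a high-dimensional space that can actually be represented in fewer dimensions. Figure 1 shows the swiss roll, a classic example of a manifold. The swiss roll sits in a higher-dimensional space, but the colors show that it can in fact be represented in a lower dimension.
In short, the nonlinear mapping methods covered today aim to capture the characteristics of the data well in a low dimension through nonlinear dimensionality reduction.&lt;/p&gt;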

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/orBAuKy.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 1 &lt;/p&gt;

&lt;p&gt;When linear reduction methods are ineffective like this, a nonlinear mapping approach is needed. Representative methods are Isomap, LLE and t-SNE, which are covered today.&lt;/p&gt;

&lt;hr /&gt;

&lt;h2 id=&quot;isomap&quot;&gt;Isomap&lt;/h2&gt;

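&lt;p&gt;Isomap can be viewed as an extension of multidimensional scaling (MDS) or principal component analysis (PCA) that combines the two. Building on PCA and MDS, covered previously, it seeks a lower-dimensional embedding that preserves the geodesic distances between all pairs of points. The geodesic distance is the distance measured along the surface of the manifold, rather than straight through the surrounding space.&lt;/p&gt;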

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/YeDG7pl.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 2 &lt;/p&gt;

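&lt;p&gt;In Figure 2, the two points are close in Euclidean distance. Their actual geodesic distance, however, is large, as the difference in their colors indicates. Isomap therefore aims for effective dimensionality reduction by using distances that reflect the actual relationship between data points.&lt;/p&gt;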

&lt;p&gt;The Isomap algorithm consists of three steps, which I explain together with code. The data used here is a generated swiss roll, shown in Figure 3.&lt;/p&gt;
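&lt;p&gt;The generator used for the post's figures is not shown here. The NumPy sketch below therefore assumes a common parameterization of the swiss roll, similar to the one in scikit-learn's make_swiss_roll. A 2-D sheet with coordinates (t, h) is rolled up in 3-D as (t cos t, h, t sin t). This is why the data are intrinsically 2-D even though they live in 3-D. The function name and parameter ranges are illustrative.&lt;/p&gt;

```python
# Illustrative swiss-roll generator; parameter ranges are assumptions.
import numpy as np


def swiss_roll(n_samples=1000, noise=0.0, seed=0):
    """Return 3-D points X and their intrinsic 2-D coordinates (t, h)."""
    rng = np.random.default_rng(seed)
    t = 1.5 * np.pi * (1 + 2 * rng.random(n_samples))   # angle along the roll
    h = 21 * rng.random(n_samples)                        # position along the roll axis
    X = np.column_stack([t * np.cos(t), h, t * np.sin(t)])
    X += noise * rng.standard_normal(X.shape)
    return X, np.column_stack([t, h])


X, intrinsic = swiss_roll(800)
print(X.shape, intrinsic.shape)   # (800, 3) (800, 2)
```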

&lt;p&gt;&lt;strong&gt;1. Build the neighborhood graph&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The first step determines which points lie close to each other on the manifold. Two approaches are used. The first connects every pair of points whose distance is smaller than a fixed threshold, epsilon. The second (KNN) connects each point to its K nearest points. Figure 3 shows the result after the first step.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/wfCi78s.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 3 &lt;/p&gt;

&lt;p&gt;Once the points are connected, the weight of each edge is the Euclidean distance between the two connected points.&lt;/p&gt;
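&lt;p&gt;Below is a minimal NumPy sketch of step 1. It assumes dense pairwise distances, which is fine for a few thousand points. Both the k-nearest-neighbor rule and the epsilon rule produce a weighted adjacency matrix whose entries are Euclidean distances, with np.inf marking pairs that are not connected. The kNN graph is symmetrized: an edge is kept if either point lists the other as a neighbor. This is a common choice, not something the post specifies. The function names are illustrative.&lt;/p&gt;

```python
# Illustrative sketch of Isomap step 1; dense O(N^2) memory, names are hypothetical.
import numpy as np


def pairwise_distances(X):
    # Euclidean distance matrix of shape (N, N)
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def knn_graph(X, k):
    # Connect each point to its k nearest neighbors (assumes no duplicate points)
    D = pairwise_distances(X)
    n = D.shape[0]
    G = np.full((n, n), np.inf)                  # np.inf = not connected
    nearest = np.argsort(D, axis=1)[:, 1:k + 1]  # column 0 is the point itself
    for i in range(n):
        G[i, nearest[i]] = D[i, nearest[i]]      # edge weight = Euclidean distance
    G = np.minimum(G, G.T)                       # keep an edge if either endpoint chose it
    np.fill_diagonal(G, 0.0)
    return G


def epsilon_graph(X, eps):
    # Connect every pair closer than eps
    D = pairwise_distances(X)
    return np.where(D >= eps, np.inf, D)


pts = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [10.0, 0.0]])
print(knn_graph(pts, k=1))
print(epsilon_graph(pts, eps=1.5))
```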

&lt;p&gt;&lt;strong&gt;2. Compute the shortest-path distances between all pairs of points&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;For two points i and j, initialize $d_G(i,j) = d_X(i,j)$ if they are connected, and $d_G(i,j) = \infty$ otherwise.
Then, for each $k = 1, \dots, N$ in turn, replace $d_G(i,j)$, the shortest distance between points i and j, with $\min(d_G(i,j), d_G(i,k) + d_G(k,j))$. Dijkstra's algorithm and Floyd's algorithm are typical choices for this step.&lt;/p&gt;
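&lt;p&gt;The update rule above is the Floyd-Warshall recursion. Below is a vectorized NumPy sketch. It takes a weighted adjacency matrix as input, with np.inf for missing edges and 0 on the diagonal. It runs in O(N^3) time, so for large N a Dijkstra-based routine such as SciPy's scipy.sparse.csgraph.shortest_path is the usual choice. The function name is illustrative.&lt;/p&gt;

```python
# Illustrative Floyd-Warshall for Isomap step 2; the function name is hypothetical.
import numpy as np


def floyd_warshall(G):
    """All-pairs shortest paths; G[i, j] is the edge length or np.inf if no edge."""
    d = G.astype(float).copy()
    n = d.shape[0]
    for k in range(n):
        # d_G(i, j) = min(d_G(i, j), d_G(i, k) + d_G(k, j)) for all i, j at once
        d = np.minimum(d, d[:, k:k + 1] + d[k:k + 1, :])
    return d


inf = np.inf
G = np.array([[0.0, 1.0, inf, inf],
              [1.0, 0.0, 2.0, inf],
              [inf, 2.0, 0.0, 7.0],
              [inf, inf, 7.0, 0.0]])
print(floyd_warshall(G))
```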

&lt;p&gt;&lt;strong&gt;3. Construct a d-dimensional embedding with MDS&lt;/strong&gt;&lt;/p&gt;

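&lt;p&gt;Let $\lambda_p$ be the p-th largest eigenvalue of the matrix $\tau(D_G)$, and $v_p^i$ the i-th component of the p-th eigenvector. The d-dimensional embedding is obtained by setting the p-th component of the coordinate vector $y_i$ to $\sqrt{\lambda_p}\, v_p^i$ for $p = 1, \dots, d$. Here $\tau(D) = -HSH/2$, where $S_{ij} = D_{ij}^2$ and $H = I - \frac{1}{N}\mathbf{1}\mathbf{1}^T$ is the centering matrix. MDS (multidimensional scaling) is a method that, given only distance data, works backward to a coordinate system that reproduces those distances.&lt;/p&gt;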
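&lt;p&gt;Below is a NumPy sketch of this step (classical MDS). It uses the double-centering definition $\tau(D) = -HSH/2$ given above. It takes the top d eigenpairs and scales each eigenvector by $\sqrt{\lambda_p}$. When the input distances are Euclidean and d is at least their intrinsic dimension, the embedding reproduces them up to rotation and reflection. The function name is illustrative.&lt;/p&gt;

```python
# Illustrative classical MDS for Isomap step 3; the function name is hypothetical.
import numpy as np


def classical_mds(D, d=2):
    """Embed N points in d dimensions from an (N, N) distance matrix D."""
    n = D.shape[0]
    S = D ** 2                                   # squared distances
    H = np.eye(n) - np.ones((n, n)) / n          # centering matrix
    tau = -0.5 * H @ S @ H                       # tau(D), the double-centered Gram matrix
    eigvals, eigvecs = np.linalg.eigh(tau)       # eigh returns ascending eigenvalues
    order = np.argsort(eigvals)[::-1][:d]        # indices of the d largest eigenvalues
    lam = np.clip(eigvals[order], 0.0, None)     # guard against tiny negative values
    return eigvecs[:, order] * np.sqrt(lam)      # y_i[p] = sqrt(lambda_p) * v_p^i


pts = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 4.0], [3.0, 4.0]])
D = np.sqrt(((pts[:, None, :] - pts[None, :, :]) ** 2).sum(axis=-1))
Y = classical_mds(D, d=2)
D_rec = np.sqrt(((Y[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1))
print(np.allclose(D, D_rec))   # True: Euclidean distances are recovered
```

&lt;p&gt;In Isomap, D is the geodesic distance matrix from step 2 rather than a Euclidean one, so the recovered coordinates approximate the unrolled manifold.&lt;/p&gt;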

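&lt;p&gt;Isomap has two main hyperparameters: the output dimension, and the neighborhood size (K) or epsilon. The output dimension is usually 2 or 3 for visualization. Choosing the neighborhood size is more difficult. If the neighborhood is too small, the geodesics fail to capture the relationships between data points well. If k is too large, wrong connections appear in the graph.&lt;/p&gt;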

&lt;p&gt;The computational complexity of Isomap is as follows.&lt;/p&gt;

\[O[D \log(k) N \log(N)] + O[N^2(k + \log(N))] + O[d N^2]\]

&lt;ul&gt;
  &lt;li&gt;N: number of training data points&lt;/li&gt;
  &lt;li&gt;D: input dimension&lt;/li&gt;
  &lt;li&gt;k: number of nearest neighbors&lt;/li&gt;
  &lt;li&gt;d: output dimension&lt;/li&gt;
&lt;/ul&gt;

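&lt;p&gt;Figure 4 shows MNIST images of the handwritten digit 2 reduced to two dimensions with Isomap. Moving down, the upper part of the 2 is drawn more roundly. Moving right, the lower part of the 2 becomes more rounded. Isomap captures both of these features accurately.&lt;/p&gt;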

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/FwctPtZ.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 4 &lt;/p&gt;

&lt;h3 id=&quot;isomap을-활용한-예시&quot;&gt;Example Using Isomap&lt;/h3&gt;

&lt;p&gt;Dimensionality reduction with Isomap on a dataset of seven MNIST handwritten digits&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Implementing Dijkstra's algorithm&lt;/li&gt;
&lt;/ul&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Graph&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;object&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nodes&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;set&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;edges&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{}&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{}&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;add_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nodes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;add_edge&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;from_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;to_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_add_edge&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;to_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_add_edge&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;from_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_add_edge&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;from_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;to_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;edges&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;setdefault&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[])&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;edges&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;to_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distance&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;dijkstra&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;graph&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;initial_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;visited_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;initial_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;nodes&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;set&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;graph&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nodes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;while&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nodes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;connected_node&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;node&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nodes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;node&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;visited_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;connected_node&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;connected_node&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;node&lt;/span&gt;
                &lt;span class=&quot;k&quot;&gt;elif&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;visited_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;visited_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;connected_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;connected_node&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;node&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;connected_node&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;break&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;nodes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;remove&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;connected_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;cur_wt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;visited_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;connected_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
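        &lt;span class=&quot;c1&quot;&gt;# use .get() so that a node added without any edge does not raise KeyError
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;edge&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;graph&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;edges&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;connected_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]):&lt;/span&gt;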
            &lt;span class=&quot;n&quot;&gt;wt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cur_wt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;graph&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;connected_node&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;edge&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;edge&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;visited_dist&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;or&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;visited_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;edge&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;visited_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;edge&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wt&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;visited_dist&lt;/span&gt;
    
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
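
&lt;p&gt;The loop above finds the closest unvisited node by scanning all remaining nodes. One call therefore takes O(N^2) time, and running it from every point takes O(N^3). A binary-heap version is sketched below as an alternative. It is a swapped-in variant, not the author's code. The name &lt;code&gt;dijkstra_heap&lt;/code&gt; is hypothetical, and the function takes plain dicts shaped like &lt;code&gt;Graph.edges&lt;/code&gt; and &lt;code&gt;Graph.distances&lt;/code&gt;.&lt;/p&gt;

```python
import heapq
import math


def dijkstra_heap(edges, distances, source):
    """Single-source shortest paths with a binary heap.

    edges:     dict node -> list of neighbor nodes (like Graph.edges above)
    distances: dict (u, v) -> edge length          (like Graph.distances above)
    Returns a dict node -> shortest distance from source (reachable nodes only).
    """
    dist = {source: 0.0}
    heap = [(0.0, source)]
    done = set()
    while heap:
        d_u, u = heapq.heappop(heap)
        if u in done:
            continue                      # stale heap entry; u was already settled
        done.add(u)
        for v in edges.get(u, []):
            nd = d_u + distances[(u, v)]
            if dist.get(v, math.inf) > nd:
                dist[v] = nd
                heapq.heappush(heap, (nd, v))
    return dist


# Toy undirected graph: A -1- B -2- C, plus a longer direct edge A -5- C
edges = {'A': ['B', 'C'], 'B': ['A', 'C'], 'C': ['B', 'A']}
distances = {('A', 'B'): 1, ('B', 'A'): 1, ('B', 'C'): 2,
             ('C', 'B'): 2, ('A', 'C'): 5, ('C', 'A'): 5}
print(dijkstra_heap(edges, distances, 'A'))  # {'A': 0.0, 'B': 1.0, 'C': 3.0}
```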

&lt;ul&gt;
  &lt;li&gt;Implementing the Isomap function&lt;/li&gt;
&lt;/ul&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;sklearn.neighbors&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NearestNeighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kneighbors_graph&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;sklearn.decomposition&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;KernelPCA&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;dijkstra&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Graph&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dijkstra&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;pickle&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;isomap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_components&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_jobs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
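    &lt;span class=&quot;c1&quot;&gt;# input and n_neighbors are unused: the N x N geodesic distance matrix is
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# presumably precomputed (e.g. with dijkstra above) and loaded from a pickle
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;open&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'./isomap_distance_matrix.p'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'rb'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;distance_matrix&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;pickle&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;load&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;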
    &lt;span class=&quot;n&quot;&gt;kernel_pca_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;KernelPCA&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_components&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_components&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                 &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;precomputed&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                 &lt;span class=&quot;n&quot;&gt;eigen_solver&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'arpack'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;max_iter&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                 &lt;span class=&quot;n&quot;&gt;n_jobs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_jobs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
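    &lt;span class=&quot;c1&quot;&gt;# tau(D_G) = -H S H / 2 with S = D_G ** 2; KernelPCA applies the centering H itself
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distance_matrix&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;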
    &lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_pca_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fit_transform&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
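
&lt;p&gt;As written, &lt;code&gt;isomap()&lt;/code&gt; ignores &lt;code&gt;input&lt;/code&gt; and &lt;code&gt;n_neighbors&lt;/code&gt;. It loads a precomputed matrix from &lt;code&gt;./isomap_distance_matrix.p&lt;/code&gt; instead, so the imported &lt;code&gt;kneighbors_graph&lt;/code&gt; is never used. A self-contained variant that computes the geodesic distances from the data is sketched below. It substitutes SciPy's &lt;code&gt;scipy.sparse.csgraph.shortest_path&lt;/code&gt; (Dijkstra) for the hand-written &lt;code&gt;dijkstra()&lt;/code&gt; above, and the function name &lt;code&gt;isomap_from_data&lt;/code&gt; is hypothetical. It follows the same three steps: neighborhood graph, shortest paths, and MDS via &lt;code&gt;KernelPCA&lt;/code&gt;.&lt;/p&gt;

```python
import numpy as np
from scipy.sparse.csgraph import shortest_path
from sklearn.decomposition import KernelPCA
from sklearn.neighbors import kneighbors_graph


def isomap_from_data(X, n_neighbors, n_components):
    """Isomap sketch that builds the geodesic distance matrix instead of loading it."""
    # 1. weighted k-NN graph (sparse matrix of Euclidean edge lengths)
    graph = kneighbors_graph(X, n_neighbors, mode='distance')
    # 2. geodesic distances d_G(i, j) via Dijkstra on the undirected graph
    distance_matrix = shortest_path(graph, method='D', directed=False)
    if np.isinf(distance_matrix).any():
        raise ValueError('k-NN graph is disconnected; try a larger n_neighbors')
    # 3. MDS: tau(D_G) = -HSH / 2; KernelPCA applies the centering H itself
    kernel = -0.5 * distance_matrix ** 2
    kpca = KernelPCA(n_components=n_components, kernel='precomputed')
    return kpca.fit_transform(kernel)


# A made-up L-shaped curve: 11 points with unit spacing along the curve
X = np.array([[i, 0.0] for i in range(6)] + [[5.0, j] for j in range(1, 6)])
Y = isomap_from_data(X, n_neighbors=2, n_components=1)
print(np.round(Y[:, 0], 3))  # position along the unrolled L, centered (sign may flip)
```

&lt;p&gt;On the L-shaped toy curve, the one-dimensional embedding recovers each point's centered position along the curve, up to sign.&lt;/p&gt;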

&lt;ul&gt;
  &lt;li&gt;Applying the model to the handwritten digits data&lt;/li&gt;
&lt;/ul&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;sklearn&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;datasets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random_projection&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;time&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;matplotlib&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offsetbox&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;mpl_toolkits.mplot3d&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Axes3D&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;isomap&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;isomap&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;datasets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;load_digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_class&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;n_samples&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_features&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;n_neighbors&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;30&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;plot_embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;title&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_max&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_max&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;figure&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ax&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;subplot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;111&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt;
                 &lt;span class=&quot;n&quot;&gt;color&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Set1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;10.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
                 &lt;span class=&quot;n&quot;&gt;fontdict&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'weight'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'bold'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'size'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;})&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;hasattr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;offsetbox&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'AnnotationBbox'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;shown_images&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# just something big
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;shown_images&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;4e-3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;k&quot;&gt;continue&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;shown_images&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shown_images&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]]&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;imagebox&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offsetbox&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;AnnotationBbox&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;offsetbox&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;OffsetImage&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;images&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cmap&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gray_r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;ax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_artist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;imagebox&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xticks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([]),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;yticks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([])&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;title&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;title&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;title&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;n_img_per_row&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;30&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;img&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_img_per_row&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_img_per_row&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_img_per_row&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ix&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_img_per_row&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;iy&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ix&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ix&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;iy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;iy&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_img_per_row&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;imshow&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cmap&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;binary&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xticks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([])&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;yticks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([])&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;title&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'A selection from the 64-dimensional digits dataset'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;rp&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;random_projection&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SparseRandomProjection&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_components&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;random_state&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;42&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;X_projected&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fit_transform&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plot_embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_projected&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;Random Projection of the digits&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;isomap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_neighbors&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_components&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_jobs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plot_embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;Isomap projection of the digits (time %.2fs)&quot;&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;embedding_three_dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;isomap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_neighbors&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_components&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_jobs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;fig&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;figure&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
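&lt;span class=&quot;c1&quot;&gt;# Axes3D(fig) no longer attaches itself to the figure in recent Matplotlib versions&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;ax&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_subplot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;projection&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'3d'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;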
&lt;span class=&quot;n&quot;&gt;ax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scatter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embedding_three_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;embedding_three_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;embedding_three_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;ul&gt;
  &lt;li&gt;Checking the final visualization&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;All seven digit classes are well separated in two dimensions, and in three dimensions each color likewise forms a clearly distinct cluster.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/lpviFTG.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 5 &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/J9Rdi7l.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 6 &lt;/p&gt;

&lt;hr /&gt;

&lt;h2 id=&quot;lle&quot;&gt;LLE&lt;/h2&gt;

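&lt;p&gt;Locally Linear Embedding (LLE) embeds data into a lower-dimensional space while preserving the linear relationships among neighboring points in the high-dimensional space. The underlying idea is that a manifold can be represented well by stitching together linear models, each fitted over a small neighborhood. LLE has the following advantages:&lt;/p&gt;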
&lt;ol&gt;
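  &lt;li&gt;It is simple to use: the main parameters are the number of neighbors k and the output dimension d.&lt;/li&gt;
  &lt;li&gt;Its optimization does not get stuck in local minima, because both the weights and the embedding have closed-form solutions.&lt;/li&gt;
  &lt;li&gt;It can produce nonlinear embeddings.&lt;/li&gt;
  &lt;li&gt;It maps high-dimensional data to a low-dimensional space.&lt;/li&gt;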
&lt;/ol&gt;

&lt;p&gt;The LLE algorithm consists of three steps.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;1. Find the nearest neighbors&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;For each data point, find its k nearest neighbors.&lt;/p&gt;
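&lt;p&gt;As a rough illustration of step 1, the sketch below finds the k nearest neighbors by brute force with NumPy, assuming Euclidean distance. The helper name &lt;code&gt;k_nearest_neighbors&lt;/code&gt; is hypothetical. It builds the full $N \times N$ distance matrix, so it costs $O(DN^2)$ time and $O(N^2)$ memory; tree-based searches such as a KD-tree scale better on large data.&lt;/p&gt;

```python
import numpy as np


def k_nearest_neighbors(X, k):
    """Hypothetical helper: indices of the k nearest neighbors of every row of X.

    X has shape (N, D); the result has shape (N, k), nearest neighbor first.
    Brute force: builds the full N x N distance matrix.
    """
    # Squared Euclidean distances via broadcasting: (N, 1, D) minus (1, N, D), summed over D
    sq_dist = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2)
    # A point must not count as its own neighbor
    np.fill_diagonal(sq_dist, np.inf)
    # Sort each row by distance and keep the k closest indices
    return np.argsort(sq_dist, axis=1)[:, :k]


# Usage: four points on a line
X = np.array([[0.0], [1.0], [3.0], [10.0]])
print(k_nearest_neighbors(X, 2))
# [[1 2]
#  [0 2]
#  [1 0]
#  [2 1]]
```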

&lt;p&gt;&lt;strong&gt;2. Construct the weight matrix&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Find the weight matrix $W$ that best reconstructs each point as a weighted sum of its k neighbors, i.e., the one that minimizes the reconstruction error:&lt;/p&gt;

\[E(W) = \sum_i \left|x_i - \sum_j W_{ij} x_j\right|^2\]

&lt;p&gt;subject to $W_{ij} = 0$ whenever $x_j$ is not a neighbor of $x_i$, and $\sum_j W_{ij} = 1$ for every $i$.&lt;/p&gt;
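&lt;p&gt;Step 2 has a closed-form solution for each point: with $Z$ holding the neighbors shifted so that $x_i$ sits at the origin, the local Gram matrix is $C = Z Z^\top$, and the optimal weights are proportional to $C^{-1}\mathbf{1}$, rescaled to sum to 1. The sketch below is a minimal NumPy version of this; the helper name &lt;code&gt;reconstruction_weights&lt;/code&gt; and the regularization strength &lt;code&gt;reg=1e-3&lt;/code&gt; (needed because $C$ is singular when k exceeds the input dimension D) are assumptions. The usage part also checks the property that makes these weights worth preserving: they do not change when the whole neighborhood is rotated, uniformly scaled, or translated.&lt;/p&gt;

```python
import numpy as np


def reconstruction_weights(x, neighbors, reg=1e-3):
    """Hypothetical helper: weights w that best reconstruct x from its neighbors.

    x has shape (D,), neighbors has shape (k, D). Minimizes
    |x - sum_j w[j] * neighbors[j]|^2 subject to sum_j w[j] = 1.
    """
    Z = neighbors - x                          # shift so that x sits at the origin
    C = Z @ Z.T                                # local Gram matrix, shape (k, k)
    # Assumed regularization: C is singular when k exceeds D, so add a small ridge
    C = C + np.eye(len(C)) * reg * np.trace(C)
    w = np.linalg.solve(C, np.ones(len(C)))    # solve C w = 1
    return w / w.sum()                         # rescale so the weights sum to 1


# Usage: x lies inside the triangle spanned by its three neighbors (k = 3, D = 2)
neighbors = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
x = np.array([0.25, 0.25])
w = reconstruction_weights(x, neighbors)
print(w)  # close to the barycentric coordinates [0.5, 0.25, 0.25]; reg shifts it slightly

# Rotate, scale and translate the whole neighborhood: the weights stay the same
theta, scale, shift = 0.7, 5.0, np.array([3.0, -2.0])
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])
w_moved = reconstruction_weights(scale * x @ R.T + shift, scale * neighbors @ R.T + shift)
print(np.allclose(w, w_moved))  # True
```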

&lt;p&gt;&lt;strong&gt;3. Partial eigendecomposition&lt;/strong&gt;&lt;/p&gt;

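&lt;p&gt;Next, reduce the dimension while preserving the weights found above as closely as possible. Let $y_i$ denote the low-dimensional point corresponding to $x_i$ (the rows of $Y$). With $W$ held fixed, find the $Y$ that minimizes the difference between each $y_i$ and the weighted sum of its neighbors $y_j$. Under the usual constraints that $Y$ is centered and has unit covariance, this minimization reduces to an eigenvalue problem in which only the few eigenvectors with the smallest eigenvalues are needed, hence the name partial eigendecomposition.&lt;/p&gt;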

\[\Phi(Y) = \sum_i \left| y_i - \sum_j W_{ij} y_j \right|^2\]
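&lt;p&gt;Step 3 can be written as an eigenvalue problem. With $W$ stored as an $N \times N$ matrix, $\Phi(Y) = \operatorname{tr}(Y^\top M Y)$ with $M = (I - W)^\top (I - W)$. Because every row of $W$ sums to 1, the constant vector is an eigenvector of $M$ with eigenvalue 0; it is discarded, and the next d eigenvectors form the embedding. The sketch below assumes a dense weight matrix, a connected neighborhood graph, and the constraints that $Y$ is centered with unit covariance; &lt;code&gt;lle_embedding&lt;/code&gt; is a hypothetical helper. For simplicity it calls the full &lt;code&gt;np.linalg.eigh&lt;/code&gt;, whereas practical implementations compute only the bottom d + 1 eigenvectors with a sparse solver.&lt;/p&gt;

```python
import numpy as np


def lle_embedding(W, d):
    """Hypothetical helper: (N, d) embedding Y from an (N, N) LLE weight matrix W.

    Row i of W holds the reconstruction weights of point i (0 for non-neighbors),
    and every row sums to 1.
    """
    N = W.shape[0]
    I_minus_W = np.eye(N) - W
    M = I_minus_W.T @ I_minus_W          # Phi(Y) = trace(Y^T M Y)
    _, eigvecs = np.linalg.eigh(M)       # symmetric solver, eigenvalues in ascending order
    # Column 0 is the constant eigenvector (eigenvalue ~0, assuming a connected graph):
    # skip it and keep the next d. Scaling by sqrt(N) enforces Y^T Y / N = I.
    return np.sqrt(N) * eigvecs[:, 1:d + 1]


# Usage: 12 points, each reconstructed from its two neighbors on a ring
N = 12
W = np.zeros((N, N))
for i in range(N):
    W[i, (i - 1) % N] = 0.5
    W[i, (i + 1) % N] = 0.5
Y = lle_embedding(W, d=2)
print(np.round(np.linalg.norm(Y, axis=1), 6))  # all 12 radii equal sqrt(2) ≈ 1.414214: a circle
```

&lt;p&gt;In this toy case the ring of neighbors is unrolled into a circle, which shows that only the neighborhood weights, not the original coordinates, drive step 3.&lt;/p&gt;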

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/SBVKuSc.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 7 &lt;/p&gt;

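&lt;p&gt;Figure 7 illustrates the LLE procedure. Although the weights and the embedding vectors are computed with linear algebra, the constraint that each point is reconstructed only from its neighbors produces a nonlinear embedding, so LLE is regarded as a nonlinear mapping.
The computational complexity of LLE is given below; its three terms correspond to the three steps above (neighbor search, weight-matrix construction, and partial eigendecomposition).&lt;/p&gt;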

\[O[D \log(k) N \log(N)] + O[D N k^3] + O[d N^2]\]

&lt;ul&gt;
  &lt;li&gt;N: number of training data points&lt;/li&gt;
  &lt;li&gt;D: input dimension&lt;/li&gt;
  &lt;li&gt;k: number of nearest neighbors&lt;/li&gt;
  &lt;li&gt;d: output dimension&lt;/li&gt;
&lt;/ul&gt;
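&lt;p&gt;scikit-learn also provides LLE as &lt;code&gt;sklearn.manifold.LocallyLinearEmbedding&lt;/code&gt;, and its user guide states the same three-term complexity. The hedged sketch below maps the symbols above onto its parameters: &lt;code&gt;n_neighbors&lt;/code&gt; is k, &lt;code&gt;n_components&lt;/code&gt; is d, a tree-based &lt;code&gt;neighbors_algorithm&lt;/code&gt; corresponds to the first term, and &lt;code&gt;eigen_solver='arpack'&lt;/code&gt; computes only the bottom eigenvectors (the partial eigendecomposition). The data size, noise level, and k = 12 are chosen to mirror the Swiss-roll example in this post; they are not tuned values.&lt;/p&gt;

```python
from sklearn.datasets import make_swiss_roll
from sklearn.manifold import LocallyLinearEmbedding

# 1,000 noisy Swiss-roll points in 3-D (D = 3); t is the position along the roll
X, t = make_swiss_roll(n_samples=1000, noise=0.05, random_state=0)

lle = LocallyLinearEmbedding(
    n_neighbors=12,                 # k
    n_components=2,                 # d
    neighbors_algorithm="kd_tree",  # tree search: the O[D log(k) N log(N)] term
    eigen_solver="arpack",          # only the bottom eigenvectors: partial decomposition
    random_state=0,                 # fixes ARPACK's random starting vector
)
Y = lle.fit_transform(X)            # shape (1000, 2)
print(Y.shape, lle.reconstruction_error_)
```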

&lt;h3 id=&quot;lle을-활용한-예시&quot;&gt;Example using LLE&lt;/h3&gt;

&lt;ul&gt;
  &lt;li&gt;Reducing 1,000 data points sampled from a Swiss roll to two dimensions&lt;/li&gt;
&lt;/ul&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;pylab&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;pl&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;

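&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;swissroll&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# N noisy points on a Swiss roll: returns an (N, 3) array and the roll parameter t&lt;/span&gt;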
    &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1000&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;noise&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.05&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pi&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;rand&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;h&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;21&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;rand&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;h&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;noise&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;squeeze&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;LLE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nRedDim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;K&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;12&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ndata&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ndim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Pairwise Euclidean distance matrix (symmetric, so only the upper triangle is computed)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ndata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ndata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ndata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ndata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ndim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Step 1: K nearest neighbours (index 0 after sorting is normally the point itself, so it is skipped)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;argsort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;neighbours&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;K&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;K&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ndata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Step 2: solve the regularized local Gram system C w = 1 and rescale w to sum to 1&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ndata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;neighbours&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kron&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ones&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;K&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:])&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;C&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;C&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;C&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;identity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;K&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-3&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linalg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;solve&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ones&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;K&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Step 3: build the cost matrix M = (I - W)^T (I - W) used for the eigendecomposition&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;M&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eye&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ndata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ndata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 1-D array of the K weights of point i&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;neighbours&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;ww&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;outer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# K x K matrix of products w[k] * w[l]&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;K&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;M&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;M&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;K&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;M&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ww&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# M is symmetric, so eigh returns real eigenvalues (in ascending order)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;evals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;evecs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linalg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eigh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;M&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ind&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;argsort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;evals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Skip the bottom eigenvector (eigenvalue ~0) and keep the next nRedDim as coordinates&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;evecs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ind&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nRedDim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ndata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;evals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;evecs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;swissroll&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;evals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;evecs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LLE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Rescale t to [0, 1] so it can be used as the gray level of each point&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;t2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;t3&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t3&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;pl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scatter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;50&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cmap&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gray&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;pl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'off'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;pl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;show&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;ul&gt;
  &lt;li&gt;Visualization of the Swiss roll data after dimensionality reduction&lt;/li&gt;
&lt;/ul&gt;

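&lt;p&gt;The gray level, which is set from t in the code above, changes smoothly from black to white across the 2-D embedding, which shows that the dimensionality reduction preserves the structure of the data well.&lt;/p&gt;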

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/ZXznmRR.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 8 &lt;/p&gt;
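
&lt;p&gt;The double loop in the LLE code above fills the cost matrix one neighbour at a time. In matrix form it computes $M = (I - A)^T (I - A)$, where $A$ is the $n \times n$ matrix whose row $i$ holds the reconstruction weights of point $i$ in its neighbours' columns. The following is a minimal sketch, not the original author's code, that checks this equivalence on random weights. A, lle_cost_matrix_loop and lle_cost_matrix_vectorized are hypothetical names, and W (shape K x n) and neighbours (shape n x K) are assumed to have the same layout as in the code above.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np

# Hypothetical helper: the same update rule as the LLE loop above.
# W: (K, n) reconstruction weights, neighbours: (n, K) neighbour indices
def lle_cost_matrix_loop(W, neighbours):
    K, n = W.shape
    M = np.eye(n)
    for i in range(n):
        w = W[:, i]
        j = neighbours[i, :]
        ww = np.outer(w, w)
        for k in range(K):
            M[i, j[k]] -= w[k]
            M[j[k], i] -= w[k]
            for l in range(K):
                M[j[k], j[l]] += ww[k, l]
    return M

# Hypothetical helper: scatter the weights into a dense n x n matrix A
# (row i = weights of point i), then form (I - A)^T (I - A) directly
def lle_cost_matrix_vectorized(W, neighbours):
    K, n = W.shape
    A = np.zeros((n, n))
    rows = np.repeat(np.arange(n), K)
    A[rows, neighbours.ravel()] = W.T.ravel()
    I_minus_A = np.eye(n) - A
    return I_minus_A.T @ I_minus_A

# Toy check: 6 points, each reconstructed from its two cyclic successors
n, K = 6, 2
rng = np.random.default_rng(0)
neighbours = np.array([[(i + 1) % n, (i + 2) % n] for i in range(n)])
W = rng.random((K, n))
W /= W.sum(axis=0)  # LLE constraint: the weights of each point sum to 1

M_loop = lle_cost_matrix_loop(W, neighbours)
M_vec = lle_cost_matrix_vectorized(W, neighbours)
print(np.allclose(M_loop, M_vec))  # True
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Because the weights of each point sum to 1, the constant vector satisfies $(I - A)\mathbf{1} = 0$, so it is an eigenvector of $M$ with eigenvalue 0. This is why the LLE code discards the first eigenvector and keeps the next nRedDim.&lt;/p&gt;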

&lt;hr /&gt;

&lt;h2 id=&quot;t-sne&quot;&gt;t-SNE&lt;/h2&gt;

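&lt;p&gt;Before t-SNE, let us first explain SNE (Stochastic Neighbor Embedding).
SNE converts the Euclidean distances between points in the high-dimensional space into conditional probabilities that represent their similarity. The similarity of point $x_j$ to point $x_i$ is the conditional probability $p_{j|i}$ that $x_i$ would pick $x_j$ as its neighbor if neighbors were picked in proportion to their density under a Gaussian (normal) distribution centered at $x_i$. A high conditional probability therefore means the two points are similar and close to each other, while a low one means they are far apart.&lt;/p&gt;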

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/7qATrbV.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 9 &lt;/p&gt;
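
&lt;p&gt;Written out, SNE defines $p_{j|i} = \frac{\exp(-\lVert x_i - x_j \rVert^2 / 2\sigma_i^2)}{\sum_{k \neq i} \exp(-\lVert x_i - x_k \rVert^2 / 2\sigma_i^2)}$ and sets $p_{i|i} = 0$. Below is a minimal NumPy sketch of this formula. The function name conditional_p is hypothetical, and the bandwidths $\sigma_i$ are simply passed in; in practice SNE chooses each $\sigma_i$ by a binary search on the perplexity, as the find_optimal_sigmas helper later in this post does.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np

# Hypothetical helper: Gaussian conditional probabilities p_{j|i} of SNE
def conditional_p(X, sigma):
    # Pairwise squared Euclidean distances ||x_i - x_j||^2
    diff = X[:, None, :] - X[None, :, :]
    d2 = np.sum(diff ** 2, axis=-1)
    # sigma may be a scalar or one bandwidth sigma_i per point (applied to row i)
    sig = np.asarray(sigma, dtype=float).reshape(-1, 1)
    logits = -d2 / (2.0 * sig ** 2)
    np.fill_diagonal(logits, -np.inf)  # a point never picks itself: p_{i|i} = 0
    logits -= logits.max(axis=1, keepdims=True)  # for numerical stability
    P = np.exp(logits)
    return P / P.sum(axis=1, keepdims=True)  # row i is the distribution p_{.|i}

# Three points on a line: x_1 is much closer to x_0 than x_2 is
X = np.array([[0.0], [1.0], [3.0]])
P = conditional_p(X, sigma=1.0)
print(P.round(3))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Row 0 puts almost all of its probability on the nearby point $x_1$, which matches the interpretation that a high $p_{j|i}$ means the two points are close.&lt;/p&gt;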

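&lt;p&gt;To carry out the SNE computation, we define low-dimensional data points that should preserve the distance information between the high-dimensional data points. Figure 10 shows the conditional probabilities, i.e. the similarities, between pairs of data points in the high-dimensional and the low-dimensional space.&lt;br /&gt;
If the distances between the high-dimensional points are well preserved among the low-dimensional points, $p_{j|i}$ and $q_{j|i}$ in Figure 10 will be similar. The KL divergence (Kullback-Leibler divergence) is a standard measure of how much one probability distribution differs from another. It is always non-negative and equals 0 only when the two distributions are identical, but it has no upper bound, so it is not limited to the range 0 to 1. As in Figure 11, training therefore minimizes the sum of the KL divergences over all data points, and the minimization is performed with gradient descent.&lt;/p&gt;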

&lt;table&gt;
  &lt;thead&gt;
    &lt;tr&gt;
      &lt;th&gt;High-dimensional data points&lt;/th&gt;
      &lt;th&gt;Low-dimensional data points&lt;/th&gt;
    &lt;/tr&gt;
  &lt;/thead&gt;
  &lt;tbody&gt;
    &lt;tr&gt;
      &lt;td&gt;$ x_i, x_j$&lt;/td&gt;
      &lt;td&gt;$ y_i, y_j $&lt;/td&gt;
    &lt;/tr&gt;
  &lt;/tbody&gt;
&lt;/table&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/9By9iXh.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 10 &lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/TSohbs1.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 11 &lt;/p&gt;
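
&lt;p&gt;The SNE cost can be written as $C = \sum_i KL(P_i \,\|\, Q_i) = \sum_i \sum_j p_{j|i} \log \frac{p_{j|i}}{q_{j|i}}$. The sketch below computes it with the hypothetical helpers kl_divergence and sne_cost; the small eps is an assumption of this sketch that only guards against log(0). It also shows that the KL divergence is 0 for identical distributions but can be larger than 1.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np

# Hypothetical helper: KL(p || q) = sum_j p_j * log(p_j / q_j)
# eps only guards against log(0); terms with p_j = 0 contribute 0
def kl_divergence(p, q, eps=1e-12):
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(p * np.log((p + eps) / (q + eps))))

# Hypothetical helper: SNE cost = sum over i of KL(P_i || Q_i),
# where row i of P holds p_{j|i} and row i of Q holds q_{j|i}
def sne_cost(P, Q, eps=1e-12):
    return kl_divergence(P, Q, eps)  # summing every entry = summing the row-wise KLs

p = [0.5, 0.5]
print(kl_divergence(p, p))                        # 0.0: identical distributions
print(kl_divergence([0.99, 0.01], [0.01, 0.99]))  # about 4.5: not bounded by 1
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;For the gradient-descent step, the standard SNE gradient is $\frac{\partial C}{\partial y_i} = 2 \sum_j (p_{j|i} - q_{j|i} + p_{i|j} - q_{i|j})(y_i - y_j)$.&lt;/p&gt;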

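&lt;p&gt;As mentioned above, SNE assumes a Gaussian distribution. The tails of the Gaussian, however, are not heavy enough: beyond a certain distance, the probability of being picked as a neighbor is nearly zero whether a point is moderately far or very far away. When the data are mapped to fewer dimensions, moderately distant points therefore get squeezed together, which is called the crowding problem. To alleviate this, t-SNE uses a t-distribution with one degree of freedom, which resembles the Gaussian but has heavier tails. Figure 12 shows the difference between the Gaussian and the t-distribution. The high-dimensional similarities are still computed with the Gaussian as in SNE (t-SNE symmetrizes them into $p_{ij} = (p_{j|i} + p_{i|j}) / 2n$), and only the low-dimensional $q_{ij}$ uses the t-distribution. With the t-distribution applied, $q_{ij}$ is $q_{ij} = \frac{(1 + \lVert y_i - y_j \rVert^2)^{-1}}{\sum_{k \neq l} (1 + \lVert y_k - y_l \rVert^2)^{-1}}$.&lt;/p&gt;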

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/FYRDh4p.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 12 &lt;/p&gt;
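
&lt;p&gt;The difference between the tails is easy to check numerically. The sketch below compares the unnormalized Gaussian kernel $\exp(-d^2/2)$ with the Student t kernel $(1 + d^2)^{-1}$ (one degree of freedom), then computes $q_{ij}$ with the t kernel for a small toy map Y. All function names are hypothetical, and this is an illustration rather than the implementation used later in the post.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np

# Hypothetical helpers: unnormalized kernels as a function of distance d
def gaussian_kernel(d):
    return np.exp(-d ** 2 / 2.0)

def student_t_kernel(d):
    # Student t with one degree of freedom (Cauchy), up to a constant factor
    return 1.0 / (1.0 + d ** 2)

for d in [1.0, 3.0, 5.0]:
    print(d, gaussian_kernel(d), student_t_kernel(d))

# Hypothetical helper: t-SNE low-dimensional affinities
# q_ij = (1 + ||y_i - y_j||^2)^-1 / sum over k != l of (1 + ||y_k - y_l||^2)^-1
def q_joint_t(Y):
    diff = Y[:, None, :] - Y[None, :, :]
    num = 1.0 / (1.0 + np.sum(diff ** 2, axis=-1))
    np.fill_diagonal(num, 0.0)  # q_ii = 0
    return num / num.sum()      # normalized over all pairs, not per row

Y = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 4.0]])
Q = q_joint_t(Y)
print(Q.round(3))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;At $d = 5$ the Gaussian value is only about $3.7 \times 10^{-6}$, while the t kernel still gives $1/26 \approx 0.038$. Because the t kernel decays polynomially rather than exponentially, moderately dissimilar points can be placed far apart in the map, which relieves the crowding problem.&lt;/p&gt;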

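&lt;p&gt;An advantage of t-SNE is that, unlike PCA, it tends to keep clusters from overlapping, which makes it very useful for visualization. It is also reported to capture local structure well while still revealing global structure such as the presence of clusters. In Figure 13, each digit forms its own well-separated cluster, and the similar-looking digits 7 and 9 appear very close to each other.&lt;/p&gt;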

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/ho78NYk.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 13 &lt;/p&gt;

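&lt;p&gt;On the other hand, the low-dimensional points are initialized randomly and the cost function is non-convex, so each run can produce a different layout with different axes and positions. This makes the output unsuitable as input features for training a model. t-SNE is also computationally expensive and slow to train: on the same data it requires much more computation time than PCA.&lt;/p&gt;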
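
&lt;p&gt;A rough calculation shows why the cost grows so quickly. The exact algorithm works with $n \times n$ matrices of pairwise probabilities, so memory and time grow quadratically with the number of points $n$. The hypothetical helper below estimates the size of one such matrix, assuming float64 storage (8 bytes per entry).&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;# Hypothetical helper: size of one dense n x n matrix in gigabytes (1 GB = 1e9 bytes)
def pairwise_matrix_gb(n, bytes_per_entry=8):
    return n * n * bytes_per_entry / 1e9

for n in [1_000, 10_000, 70_000]:
    print(n, round(pairwise_matrix_gb(n), 2))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;For all 70,000 MNIST images, a single such matrix already takes about 39.2 GB, whereas PCA can be computed from a $d \times d$ covariance matrix, which for 28 x 28 MNIST images is only 784 x 784.&lt;/p&gt;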

&lt;h3 id=&quot;t-sne를-활용한-예시&quot;&gt;Example using t-SNE&lt;/h3&gt;

&lt;p&gt;Reducing five MNIST handwritten digits to two dimensions with t-SNE and visualizing the result&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Defining the functions used to train t-SNE&lt;/li&gt;
&lt;/ul&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;

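&lt;span class=&quot;c1&quot;&gt;# Negative squared Euclidean distances -||x_i - x_j||^2 for every pair of rows of X&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;neg_distance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;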
    &lt;span class=&quot;n&quot;&gt;sum_X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;square&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sum_X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sum_X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;D&lt;/span&gt;
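&lt;span class=&quot;c1&quot;&gt;# Row-wise softmax; diagonal entries are set to (almost) zero so a point is not its own neighbour&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;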
    &lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fill_diagonal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-8&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e_x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Conditional probabilities p_{j|i}: softmax over row i of -||x_i - x_j||^2 / (2 sigma_i^2)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;calc_prob_matrix&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigmas&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigmas&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;two_sig_sq&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;2.&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;square&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sigmas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;two_sig_sq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Perplexity of each row: 2 ** H, where H is the Shannon entropy in bits&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_perplexity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prob_matrix&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;entropy&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prob_matrix&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;log2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prob_matrix&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;perplexity&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;entropy&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;perplexity&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;perplexity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigmas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_perplexity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;calc_prob_matrix&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigmas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Bisection: find the input of an increasing function fn whose output is within tol of target&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;binary_search&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tol&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1e-10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;max_iter&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10000&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                  &lt;span class=&quot;n&quot;&gt;lower&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1e-20&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;upper&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1000.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;max_iter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;guess&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lower&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;upper&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;2.&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;val&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;guess&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;val&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;upper&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;guess&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;lower&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;guess&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;abs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;val&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tol&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;break&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;guess&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# For each point i, binary-search sigma_i so that row i reaches the target perplexity&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;find_optimal_sigmas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target_perplexity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sigmas&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigma&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; \
            &lt;span class=&quot;n&quot;&gt;perplexity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sigma&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;correct_sigma&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;binary_search&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target_perplexity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;sigmas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;correct_sigma&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sigmas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Symmetrize the conditionals: p_ij = (p_{j|i} + p_{i|j}) / 2n&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;p_conditional_to_joint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;P&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;2.&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Low-dimensional joint probabilities q_ij with a Student t kernel (1 degree of freedom);&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# also returns the unnormalized kernel matrix&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;q_joint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;inv_distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;neg_distance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;power&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;inv_distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# (1 + ||y_i - y_j||^2)^-1&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fill_diagonal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;p_joint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target_perplexity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;neg_distance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sigmas&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;find_optimal_sigmas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target_perplexity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;p_conditional&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;calc_prob_matrix&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigmas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;P&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p_conditional_to_joint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;p_conditional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;P&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;tsne_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pq_diff&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;P&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Q&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pq_expanded&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pq_diff&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;y_diffs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;distances_expanded&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;y_diffs_wt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_diffs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distances_expanded&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;4.&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pq_expanded&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_diffs_wt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;t_SNE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_component&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_iters&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;500&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;learning_rate&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;10.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;momentum&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;normal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.0001&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_component&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;P&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p_joint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;20&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# high-dimensional affinities with target perplexity 20&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;momentum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;Y_m2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;Y_m1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_iters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;Q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_joint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;grads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tsne_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;learning_rate&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grads&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# gradient descent step on the KL divergence&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;momentum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;momentum&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y_m1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y_m2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# momentum term: momentum * (Y(t-1) - Y(t-2))&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;Y_m2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y_m1&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;Y_m1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
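
&lt;p&gt;To sanity-check tsne_grad, one option is to compare the analytic gradient dC/dy_i = 4 Σ_j (p_ij - q_ij)(y_i - y_j)(1 + ||y_i - y_j||^2)^-1 with a central finite-difference estimate of the KL divergence C = Σ p_ij log(p_ij / q_ij). The sketch below is a minimal, NumPy-only illustration on a small random symmetric P; q_joint and tsne_grad are self-contained re-implementations mirroring the functions above, while kl_divergence and numerical_grad are hypothetical helpers that are not part of the original post.&lt;/p&gt;

```python
import numpy as np

# Self-contained versions of q_joint / tsne_grad above, with the
# Student-t kernel (1 + ||y_i - y_j||^2)^-1 written out explicitly.
def q_joint(Y):
    sq_dists = np.sum((Y[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    inv = 1.0 / (1.0 + sq_dists)
    np.fill_diagonal(inv, 0.0)          # q_ii is defined as 0
    return inv / inv.sum(), inv

def tsne_grad(P, Q, Y, inv):
    # dC/dy_i = 4 * sum_j (p_ij - q_ij) (y_i - y_j) (1 + ||y_i - y_j||^2)^-1
    y_diffs = Y[:, None, :] - Y[None, :, :]
    return 4.0 * np.sum(((P - Q) * inv)[:, :, None] * y_diffs, axis=1)

# Hypothetical helper: KL(P || Q) over the off-diagonal pairs only
def kl_divergence(P, Y):
    Q, _ = q_joint(Y)
    off = ~np.eye(len(Y), dtype=bool)   # skip the undefined i == j terms
    return np.sum(P[off] * np.log(P[off] / Q[off]))

# Hypothetical helper: central finite differences, one coordinate at a time
def numerical_grad(P, Y, eps=1e-6):
    grad = np.zeros_like(Y)
    for i in range(Y.shape[0]):
        for d in range(Y.shape[1]):
            Y_plus, Y_minus = Y.copy(), Y.copy()
            Y_plus[i, d] += eps
            Y_minus[i, d] -= eps
            grad[i, d] = (kl_divergence(P, Y_plus) - kl_divergence(P, Y_minus)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
n = 6
A = rng.random((n, n))
P = A + A.T                 # symmetric, like the output of p_conditional_to_joint
np.fill_diagonal(P, 0.0)
P /= P.sum()                # joint probabilities sum to 1
Y = rng.normal(size=(n, 2))

Q, inv = q_joint(Y)
analytic = tsne_grad(P, Q, Y, inv)
numeric = numerical_grad(P, Y)
print("max abs difference:", np.max(np.abs(analytic - numeric)))
```

&lt;p&gt;A check like this catches kernel mistakes early: if the Student-t similarities are computed incorrectly, the analytic and numerical gradients no longer agree.&lt;/p&gt;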

&lt;ul&gt;
  &lt;li&gt;Training t-SNE on handwritten digits (scikit-learn's load_digits, an 8×8 MNIST-style dataset rather than MNIST itself)&lt;/li&gt;
&lt;/ul&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;sklearn&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;datasets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;time&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;matplotlib&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offsetbox&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tsne&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t_SNE&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;datasets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;load_digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_class&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 8x8 handwritten digits, classes 0-4 only&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;n_samples&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_features&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;plot_embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;title&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_max&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_max&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;figure&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ax&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;subplot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;111&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt;
                 &lt;span class=&quot;n&quot;&gt;color&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Set1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;10.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
                 &lt;span class=&quot;n&quot;&gt;fontdict&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'weight'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'bold'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'size'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;})&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;hasattr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;offsetbox&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'AnnotationBbox'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;shown_images&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# just something big
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;shown_images&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;4e-3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;k&quot;&gt;continue&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;shown_images&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shown_images&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]]&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;imagebox&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offsetbox&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;AnnotationBbox&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;offsetbox&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;OffsetImage&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;digits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;images&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cmap&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gray_r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;ax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_artist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;imagebox&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;xticks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([]),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;yticks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([])&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;title&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;title&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;title&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Computing t_SNE embedding&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t_SNE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_component&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;plot_embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;t_SNE projection of the digits (time %.2fs)&quot;&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;plt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;show&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# display the figure when run as a script&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
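
&lt;p&gt;As a cross-check, the same X can also be embedded with scikit-learn's sklearn.manifold.TSNE. This is a minimal sketch assuming scikit-learn is installed (it already is if the script above runs); perplexity=20 and init=&quot;random&quot; are chosen here only to mirror the from-scratch version and are not settings taken from the original post.&lt;/p&gt;

```python
from time import time

from sklearn import datasets
from sklearn.manifold import TSNE

# Same data as above: the 8x8 digits restricted to classes 0-4
digits = datasets.load_digits(n_class=5)
X = digits.data

t0 = time()
# perplexity=20 mirrors p_joint(X, 20); init="random" mirrors the
# small random Gaussian initialisation of Y in t_SNE
reference = TSNE(n_components=2, perplexity=20, init="random",
                 random_state=0).fit_transform(X)
print("sklearn TSNE: %.2fs, embedding shape %s" % (time() - t0, reference.shape))
```

&lt;p&gt;The two embeddings should not be expected to match point for point: the t-SNE objective is non-convex, and scikit-learn uses the Barnes-Hut approximation and early exaggeration by default, so only the overall cluster structure is comparable.&lt;/p&gt;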

&lt;ul&gt;
  &lt;li&gt;Visualizing the learned embedding&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Each digit forms its own clearly separated cluster, and the similarly shaped ‘2’ and ‘3’ clusters lie next to each other.&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;&lt;img width=&quot;500&quot; height=&quot;auto&quot; src=&quot;https://i.imgur.com/6Wz4nNi.png&quot; /&gt;&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt; Figure 14 &lt;/p&gt;
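
&lt;p&gt;To check an observation such as the 2/3 adjacency numerically rather than by eye, one rough option is to compare class centroids in the 2-D embedding. The helpers below (class_centroid_distances, nearest_class) are hypothetical and not part of the original post. Because t-SNE preserves local neighbourhoods much better than distances between clusters, centroid distances should be read only as a loose hint.&lt;/p&gt;

```python
import numpy as np

def class_centroid_distances(embedding, labels):
    """Pairwise Euclidean distances between per-class centroids (hypothetical helper)."""
    classes = np.unique(labels)
    centroids = np.array([embedding[labels == c].mean(axis=0) for c in classes])
    diffs = centroids[:, None, :] - centroids[None, :, :]
    return classes, np.sqrt((diffs ** 2).sum(axis=-1))

def nearest_class(embedding, labels):
    """Map each class to the class whose centroid is closest (hypothetical helper)."""
    classes, D = class_centroid_distances(embedding, labels)
    np.fill_diagonal(D, np.inf)          # ignore each class's distance to itself
    return {int(c): int(classes[j]) for c, j in zip(classes, D.argmin(axis=1))}

# Usage with the script above (not run here): nearest_class(embedding, y)
```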
</description>
        <pubDate>Sat, 20 Oct 2018 00:00:00 +0000</pubDate>
        <link>https://woosikyang.github.io/first-post.html</link>
        <guid isPermaLink="true">https://woosikyang.github.io/first-post.html</guid>
        
        <category>Dimensionality Reduction</category>
        
        
        <category>Dimensionality</category>
        
        <category>Reduction</category>
        
      </item>
    
  </channel>
</rss>