From 240087fd56289a20058fd1a37dfaf1879a57a352 Mon Sep 17 00:00:00 2001 From: Jens Jasche Date: Mon, 10 Jun 2024 13:53:47 +0200 Subject: [PATCH] updated script --- Training_IC_on_Demand.ipynb | 64 +++++++++++++++++++++++++++---------- models/trainer.py | 2 +- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/Training_IC_on_Demand.ipynb b/Training_IC_on_Demand.ipynb index e375705..ae11531 100644 --- a/Training_IC_on_Demand.ipynb +++ b/Training_IC_on_Demand.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "f7c63b6e", + "id": "0ca17f2a", "metadata": {}, "source": [ "# Create ICs on Demand" @@ -10,8 +10,8 @@ }, { "cell_type": "code", - "execution_count": 7, - "id": "7fefd6b2", + "execution_count": 1, + "id": "c5970fbc", "metadata": {}, "outputs": [], "source": [ @@ -26,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "e5aa8c20", + "id": "dff1b4d5", "metadata": {}, "source": [ "# Setup the model" @@ -34,18 +34,16 @@ }, { "cell_type": "code", - "execution_count": 8, - "id": "3d84fb99", + "execution_count": 2, + "id": "3b06a811", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "(32, 32, 32)\n", "Trainer model: tm\n", - "load initial state\n", - "save initial state\n" + "load initial state\n" ] } ], @@ -61,7 +59,7 @@ }, { "cell_type": "markdown", - "id": "476e8808", + "id": "eb5f1e83", "metadata": {}, "source": [ "# test training with white noise" @@ -69,8 +67,8 @@ }, { "cell_type": "code", - "execution_count": 11, - "id": "bba2858b", + "execution_count": 6, + "id": "3bdfe2a8", "metadata": {}, "outputs": [ { @@ -78,7 +76,6 @@ "output_type": "stream", "text": [ "Train model...\n", - "test [[[3.14149864]]]\n", "Training done\n" ] }, @@ -88,22 +85,57 @@ "0" ] }, - "execution_count": 11, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "for i in np.arange(100):\n", + "for i in np.arange(1000):\n", " x_train = np.random.normal(0,0.01,shape)+3.1415\n", " tm.train_single(x_train, silent=True)\n", "tm.transfer(silent=False)" ] }, + { + "cell_type": "code", + "execution_count": 7, + "id": "65a41e1c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.006170461447273722 1.006313544079366\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQTUlEQVR4nO3df4ylVX3H8feni2DV1kWZULq76W7ixoZSG80ESUha4loFJC5/IIG2uiLNpgm2WG0U9A+StiYYGxFTa7MB6toSkaCGjWJ1ixDTP6AMP6QCohMUdzcgoyCaEmu3fvvHPWuvy+yPmTtz7+yc9yuZ7POc59z7nHky+7lnznOeM6kqJEl9+JVJN0CSND6GviR1xNCXpI4Y+pLUEUNfkjpy3KQbcDgnnXRSbdy4cdLNkKRjyr333vuDqpqa79iKDv2NGzcyMzMz6WZI0jElyeOHOubwjiR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdWRFP5ErrSYbr/jiL7a/e/WbJtgS9cyeviR1xNCXpI4cMfST3JDkqSTfGCr7cJJvJnkwyeeTrB06dmWS2SSPJnnjUPnZrWw2yRVL/p1Iko7oaHr6nwTOPqhsN3BaVb0K+BZwJUCSU4GLgN9pr/mHJGuSrAE+DpwDnApc3OpKksboiDdyq+prSTYeVPaVod27gAva9lbgpqr6b+A7SWaB09ux2ap6DCDJTa3uw6M1X1q9vPGr5bAUY/rvAL7UttcBe4aO7W1lhyp/niTbk8wkmZmbm1uC5kmSDhgp9JN8ANgP3Lg0zYGq2lFV01U1PTU17x9+kSQt0qLn6Sd5O3AesKWqqhXvAzYMVVvfyjhMuSRpTBbV009yNvBe4M1V9dzQoV3ARUlOSLIJ2Az8B3APsDnJpiTHM7jZu2u0pkuSFuqIPf0knwbOAk5Kshe4isFsnROA3UkA7qqqP6uqh5LczOAG7X7gsqr63/Y+7wS+DKwBbqiqh5bh+5EkHcbRzN65eJ7i6w9T/4PAB+cpvw24bUGtk45BzrrRSubaO9IyGv4AOFS5HwwaJ5dhkKSO2NOXJuxQvw1Iy8GeviR1xNCXpI4Y+pLUEUNfkjpi6EtSR5y9Ix0DnNevpWJPX5I6Yk9fWiTn1+tYZE9fkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcQpm9Ixxge1NAp7+pLUEXv60gL4QJaOdfb0Jakjhr4kdcTQl6SOGPqS1BFv5EqrhFM5dTSO2NNPckOSp5J8Y6jsZUl2J/l2+/fEVp4kH0sym+TBJK8Zes22Vv/bSbYtz7cjSTqcoxne+SRw9kFlVwC3V9Vm4Pa2D3AOsLl9bQc+AYMPCeAq4LXA6cBVBz4oJEnjc8TQr6qvAU8fVLwV2Nm2dwLnD5V/qgbuAtYmOQV4I7C7qp6uqmeA3Tz/g0SStMwWeyP35Kp6om0/CZzcttcBe4bq7W1lhyp/niTbk8wkmZmbm1tk8yRJ8xl59k5VFVBL0JYD77ejqqaranpqamqp3laSxOJn73w/ySlV9UQbvnmqle8DNgzVW9/K9gFnHVR+5yLPLalxWQgt1GJ7+ruAAzNwtgG3DpW/rc3iOQN4tg0DfRl4Q5IT2w3cN7QySdIYHbGnn+TTDHrpJyXZy2AWztXAzUkuBR4HLmzVbwPOBWaB54BLAKrq6SR/A9zT6v11VR18c1iStMyOGPpVdfEhDm2Zp24Blx3ifW4AblhQ6yRJS8plGCSpIy7DIB2BN0u1mhj60irkOjw6FId3JKkjhr4kdcTQl6SOGPqS1BFDX5I64uwdaZU7eMqps3n6Zk9fkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOOE9fmofLKWu1sqcvSR0x9CWpI4a+JHXE0Jekjhj6ktQRZ+9IjTN21ANDX+qMfzS9byMN7yT5yyQPJflGkk8neWGSTUnuTjKb5DNJjm91T2j7s+34xiX5DiRJR23RoZ9kHfAXwHRVnQasAS4CPgRcU1WvAJ4BLm0vuRR4ppVf0+pJksZo1Bu5xwG/muQ44EXAE8DrgFva8Z3A+W17a9unHd+SJCOeX5K0AIsO/araB/wd8D0GYf8scC/wo6ra36rtBda17XXAnvba/a3+yw9+3yTbk8wkmZmbm1ts8yRJ8xhleOdEBr33TcBvAi8Gzh61QVW1o6qmq2p6ampq1LeTJA0ZZXjn9cB3qmquqv4H+BxwJrC2DfcArAf2te19wAaAdvylwA9HOL8kaYFGCf3vAWckeVEbm98CPAzcAVzQ6mwDbm3bu9o+7fhXq6pGOL8kaYEWPU+/qu5OcgtwH7AfuB/YAXwRuCnJ37ay69tLrgf+Ocks8DSDmT7SRPlAlnoz0sNZVXUVcNVBxY8Bp89T96fAW0Y5nyRpNK69I0kdMfQlqSOuvSN1zHV4+mNPX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0JekjjhPX91xvR31zJ6+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xCdyJQH+Fa1e2NOXpI4Y+pLUEYd31AUXWZMGRgr9JGuB64DTgALeATwKfAbYCHwXuLCqnkkS4FrgXOA54O1Vdd8o55e0PBzfX71GHd65FvjXqvpt4PeAR4ArgNurajNwe9sHOAfY3L62A58Y8dySpAVadOgneSnw+8D1AFX1s6r6EbAV2Nmq7QTOb9tbgU/VwF3A2iSnLPb8kqSFG6WnvwmYA/4pyf1JrkvyYuDkqnqi1XkSOLltrwP2DL1+byuTJI3JKKF/HPAa4BNV9Wrgv/j/oRwAqqoYjPUftSTbk8wkmZmbmxuheZKkg40S+nuBvVV1d9u/hcGHwPcPDNu0f59qx/cBG4Zev76V/ZKq2lFV01U1PTU1NULzJEkHW3ToV9WTwJ4kr2xFW4CHgV3Atla2Dbi1be8C3paBM4Bnh4aBJEljMOo8/T8HbkxyPPAYcAmDD5Kbk1wKPA5c2OrexmC65iyDKZuXjHhuSdICjRT6VfUAMD3PoS3z1C3gslHOJ0kajcswSFJHDH1J6oihL0kdccE1rVousiY9nz19SeqIoS9JHXF4R9Jhuczy6mJPX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXEKZtaNXwCVzoye/qS1BFDX5I64vCOpKPm07nHPnv6ktQRQ1+SOmLoS1JHHNPXMc1pmtLC2NOXpI4Y+pLUEUNfkjpi6EtSR0YO/SRrktyf5Attf1OSu5PMJvlMkuNb+Qltf7Yd3zjquSVJC7MUPf3LgUeG9j8EXFNVrwCeAS5t5ZcCz7Tya1o9SceojVd88RdfOnaMFPpJ1gNvAq5r+wFeB9zSquwEzm/bW9s+7fiWVl+SNCaj9vQ/CrwX+Hnbfznwo6ra3/b3Auva9jpgD0A7/myr/0uSbE8yk2Rmbm5uxOZJkoYtOvSTnAc8VVX3LmF7qKodVTVdVdNTU1NL+daS1L1Rnsg9E3hzknOBFwK/DlwLrE1yXOvNrwf2tfr7gA3A3iTHAS8FfjjC+dUpx5ClxVt0T7+qrqyq9VW1EbgI+GpV/TFwB3BBq7YNuLVt72r7tONfrapa7PklSQu3HPP03we8O8ksgzH761v59cDLW/m7gSuW4dySpMNYkgXXqupO4M62/Rhw+jx1fgq8ZSnOp/44pCMtDZ/IlaSOuLSypJH5ZxSPHfb0Jakjhr4kdcThHUlLyqGelc2eviR1xNCXpI4Y+pLUEUNfkjrijVytWD6FKy09e/qS1BFDX5I6YuhLUkcMfUnqiKEvSR1x9o6kZeOSDCuPPX1J6oihL0kdMfQlqSOGviR1xBu5WlFcekFaXvb0Jakjhr4kdcThHUlj4Zz9lWHRPf0kG5LckeThJA8lubyVvyzJ7iTfbv+e2MqT5GNJZpM8mOQ1S/VNSJKOzijDO/uB91TVqcAZwGVJTgWuAG6vqs3A7W0f4Bxgc/vaDnxihHNLkhZh0cM7VfUE8ETb/kmSR4B1wFbgrFZtJ3An8L5W/qmqKuCuJGuTnNLeRx1zxo40PktyIzfJRuDVwN3AyUNB/iRwctteB+wZetneVnbwe21PMpNkZm5ubimaJ0lqRg79JC8BPgu8q6p+PHys9eprIe9XVTuqarqqpqempkZtniRpyEizd5K8gEHg31hVn2vF3z8wbJPkFOCpVr4P2DD08vWtTFLHnNUzXqPM3glwPfBIVX1k6NAuYFvb3gbcOlT+tjaL5wzgWcfzJWm8Runpnwm8FfjPJA+0svcDVwM3J7kUeBy4sB27DTgXmAWeAy4Z4dw6xnnzVpqMUWbv/DuQQxzeMk/9Ai5b7PkkrR5+6E+OyzBIUkcMfUnqiKEvSR1xwTWNjeO40uTZ05ekjhj6ktQRQ1+SOuKYvqQV41D3fVyeYenY05ekjhj6ktQRh3e0rJymKa0shr6WhMvjajn587V0HN6RpI7Y09eSc0hHWrns6UtSRwx9SeqIwzuSjlne4F04Q18L4ni9Js2fwdE4vCNJHbGnL2lVcKjn6Bj6OiJ/nZZWD4d3JKkj9vQlrToO9RyaoS+pG34YQKpq0m04pOnp6ZqZmZl0M1Y1x+ulgdX0IZDk3qqanu/Y2Hv6Sc4GrgXWANdV1dXjboMkHayX3wLGGvpJ1gAfB/4Q2Avck2RXVT08znb0wl68tDir+QNg3D3904HZqnoMIMlNwFag69Bf6A+YYS6Nz9H83d6j+T+50P/by/VhM9Yx/SQXAGdX1Z+2/bcCr62qdw7V2Q5sb7uvBB4dWwPH4yTgB5NuxArgdRjwOgx4HQaW6jr8VlVNzXdgxc3eqaodwI5Jt2O5JJk51A2WnngdBrwOA16HgXFch3E/nLUP2DC0v76VSZLGYNyhfw+wOcmmJMcDFwG7xtwGSerWWId3qmp/kncCX2YwZfOGqnponG1YAVbt0NUCeR0GvA4DXoeBZb8OK/rhLEnS0nLBNUnqiKEvSR0x9CcoyXuSVJKTJt2WSUjy4STfTPJgks8nWTvpNo1TkrOTPJpkNskVk27PJCTZkOSOJA8neSjJ5ZNu06QkWZPk/iRfWM7zGPoTkmQD8Abge5NuywTtBk6rqlcB3wKunHB7xmZoSZJzgFOBi5OcOtlWTcR+4D1VdSpwBnBZp9cB4HLgkeU+iaE/OdcA7wW6vZNeVV+pqv1t9y4Gz2304hdLklTVz4ADS5J0paqeqKr72vZPGITeusm2avySrAfeBFy33Ocy9CcgyVZgX1V9fdJtWUHeAXxp0o0Yo3XAnqH9vXQYdsOSbAReDdw94aZMwkcZdAJ/vtwnWnHLMKwWSf4N+I15Dn0AeD+DoZ1V73DXoapubXU+wODX/BvH2TatHEleAnwWeFdV/XjS7RmnJOcBT1XVvUnOWu7zGfrLpKpeP195kt8FNgFfTwKDIY37kpxeVU+OsYljcajrcECStwPnAVuqr4dGXJKkSfICBoF/Y1V9btLtmYAzgTcnORd4IfDrSf6lqv5kOU7mw1kTluS7wHRVdbfCYPuDOh8B/qCq5ibdnnFKchyDm9dbGIT9PcAf9faEegY9n53A01X1rgk3Z+JaT/+vquq85TqHY/qapL8Hfg3YneSBJP846QaNS7uBfWBJkkeAm3sL/OZM4K3A69rPwAOtx6tlYk9fkjpiT1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI78H+SbaGw0YK3IAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "x = model.generate()\n", + "print(np.mean(x[nlevel].flatten()),np.std(x[nlevel].flatten()))\n", + "\n", + "\n", + "plt.hist(x[nlevel].flatten(),bins=100)\n", + "plt.show()" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "75b0b971", + "id": "be69ddd5", "metadata": {}, "outputs": [], "source": [] diff --git a/models/trainer.py b/models/trainer.py index aaf6fa1..db55902 100644 --- a/models/trainer.py +++ b/models/trainer.py @@ -112,7 +112,7 @@ class trainer: self.gen_model.model[0].b = self.model[0].mean self.gen_model.model[0].w = np.sqrt(self.model[0].var) - print('test',self.gen_model.model[0].b) + #print('test',self.gen_model.model[0].b) if( not silent): print('Training done')